[flang-commits] [flang] [NFC][flang][OpenMP] Split `DataSharing` and `Clause` processors (PR #81973)
Kareem Ergawy via flang-commits
flang-commits at lists.llvm.org
Wed Feb 21 04:30:46 PST 2024
https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/81973
>From d277f105a946268182bbcf516191229b7be528a1 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Thu, 15 Feb 2024 06:06:37 -0600
Subject: [PATCH] [NFC][flang][OpenMP] Split `DataSharing` and `Clause`
processors
This started as an experiment to reduce the compilation time when
iterating on `Lower/OpenMP.cpp`, which is currently too slow. To that
end, I split `DataSharingProcessor` and `ClauseProcessor` (as well as
`ReductionProcessor`) into their own files and extracted some shared code
into a utilities file.
This resulted in a slightly better organization of the OpenMP lowering
code, hence this NFC PR.
As for compilation time, this unfortunately does not help much (it only
shaves a few seconds off the compilation of `OpenMP.cpp`). From what I
have learned, the bottleneck seems to be `DirectivesCommon.h` and
`PFTBuilder.h`, both of which spend a lot of time in template
instantiation.
---
flang/lib/Lower/CMakeLists.txt | 6 +-
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 880 +++++++
flang/lib/Lower/OpenMP/ClauseProcessor.h | 305 +++
.../lib/Lower/OpenMP/DataSharingProcessor.cpp | 350 +++
flang/lib/Lower/OpenMP/DataSharingProcessor.h | 89 +
flang/lib/Lower/{ => OpenMP}/OpenMP.cpp | 2040 +----------------
flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 431 ++++
flang/lib/Lower/OpenMP/ReductionProcessor.h | 138 ++
flang/lib/Lower/OpenMP/Utils.cpp | 99 +
flang/lib/Lower/OpenMP/Utils.h | 68 +
10 files changed, 2371 insertions(+), 2035 deletions(-)
create mode 100644 flang/lib/Lower/OpenMP/ClauseProcessor.cpp
create mode 100644 flang/lib/Lower/OpenMP/ClauseProcessor.h
create mode 100644 flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
create mode 100644 flang/lib/Lower/OpenMP/DataSharingProcessor.h
rename flang/lib/Lower/{ => OpenMP}/OpenMP.cpp (55%)
create mode 100644 flang/lib/Lower/OpenMP/ReductionProcessor.cpp
create mode 100644 flang/lib/Lower/OpenMP/ReductionProcessor.h
create mode 100644 flang/lib/Lower/OpenMP/Utils.cpp
create mode 100644 flang/lib/Lower/OpenMP/Utils.h
diff --git a/flang/lib/Lower/CMakeLists.txt b/flang/lib/Lower/CMakeLists.txt
index b13d415e02f1d9..5577a60f1daeac 100644
--- a/flang/lib/Lower/CMakeLists.txt
+++ b/flang/lib/Lower/CMakeLists.txt
@@ -24,7 +24,11 @@ add_flang_library(FortranLower
LoweringOptions.cpp
Mangler.cpp
OpenACC.cpp
- OpenMP.cpp
+ OpenMP/ClauseProcessor.cpp
+ OpenMP/DataSharingProcessor.cpp
+ OpenMP/OpenMP.cpp
+ OpenMP/ReductionProcessor.cpp
+ OpenMP/Utils.cpp
PFTBuilder.cpp
Runtime.cpp
SymbolMap.cpp
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
new file mode 100644
index 00000000000000..4e3951492fb65b
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -0,0 +1,880 @@
+//===-- ClauseProcessor.cpp -------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "ClauseProcessor.h"
+
+#include "flang/Lower/PFTBuilder.h"
+#include "flang/Parser/tools.h"
+#include "flang/Semantics/tools.h"
+
+namespace Fortran {
+namespace lower {
+namespace omp {
+
+/// Report (via TODO) map operand types that lowering cannot handle yet.
+/// One level of `fir.ref` is looked through first; a `fir.box` operand is
+/// only accepted when its element type is a `fir.ptr`.
+static void checkMapType(mlir::Location location, mlir::Type type) {
+  if (auto refTy = type.dyn_cast<fir::ReferenceType>())
+    type = refTy.getElementType();
+
+  auto boxTy = type.dyn_cast_or_null<fir::BoxType>();
+  if (!boxTy || boxTy.getElementType().isa<fir::PointerType>())
+    return;
+  TODO(location, "OMPD_target_data MapOperand BoxType");
+}
+
+/// Map a parser schedule modifier onto the MLIR OpenMP schedule modifier.
+/// A value not matched by any case yields `none`.
+static mlir::omp::ScheduleModifier
+translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) {
+  using ModType = Fortran::parser::OmpScheduleModifierType::ModType;
+  mlir::omp::ScheduleModifier translated = mlir::omp::ScheduleModifier::none;
+  switch (m.v) {
+  case ModType::Monotonic:
+    translated = mlir::omp::ScheduleModifier::monotonic;
+    break;
+  case ModType::Nonmonotonic:
+    translated = mlir::omp::ScheduleModifier::nonmonotonic;
+    break;
+  case ModType::Simd:
+    translated = mlir::omp::ScheduleModifier::simd;
+    break;
+  }
+  return translated;
+}
+
+/// Return the non-SIMD modifier of a SCHEDULE clause, or `none`.
+/// Modifiers may appear in any order, so the first one that is not SIMD is
+/// translated. With no modifier, or only SIMD modifiers, the result is `none`.
+static mlir::omp::ScheduleModifier
+getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) {
+  using ModType = Fortran::parser::OmpScheduleModifierType::ModType;
+  const auto &modifier =
+      std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
+  if (!modifier)
+    return mlir::omp::ScheduleModifier::none;
+
+  const auto &first =
+      std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
+  if (first.v.v != ModType::Simd)
+    return translateScheduleModifier(first.v);
+
+  const auto &second =
+      std::get<std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
+          modifier->t);
+  if (second && second->v.v != ModType::Simd)
+    return translateScheduleModifier(second->v);
+  return mlir::omp::ScheduleModifier::none;
+}
+
+/// Return `simd` if either of the (up to two) SCHEDULE modifiers is SIMD,
+/// otherwise `none`.
+static mlir::omp::ScheduleModifier
+getSimdModifier(const Fortran::parser::OmpScheduleClause &x) {
+  using ModType = Fortran::parser::OmpScheduleModifierType::ModType;
+  const auto &modifier =
+      std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
+  if (!modifier)
+    return mlir::omp::ScheduleModifier::none;
+
+  const auto &first =
+      std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
+  const auto &second =
+      std::get<std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
+          modifier->t);
+  const bool hasSimd =
+      first.v.v == ModType::Simd || (second && second->v.v == ModType::Simd);
+  return hasSimd ? mlir::omp::ScheduleModifier::simd
+                 : mlir::omp::ScheduleModifier::none;
+}
+
+/// Lower one ALLOCATE clause. Every object in the clause's list contributes
+/// one allocator value to `allocatorOperands`, and the objects themselves are
+/// appended to `allocateOperands`. The allocator is the ALLOCATOR modifier's
+/// expression when present, otherwise the i32 constant 1. An ALIGN modifier
+/// is not implemented yet.
+static void
+genAllocateClause(Fortran::lower::AbstractConverter &converter,
+                  const Fortran::parser::OmpAllocateClause &ompAllocateClause,
+                  llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
+                  llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
+  using AllocateModifier = Fortran::parser::OmpAllocateClause::AllocateModifier;
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
+  Fortran::lower::StatementContext stmtCtx;
+
+  const auto &objects =
+      std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
+  const auto &modifier =
+      std::get<std::optional<AllocateModifier>>(ompAllocateClause.t);
+
+  // Only the plain ALLOCATOR sub-modifier is supported for now.
+  const auto *allocatorMod =
+      modifier ? std::get_if<AllocateModifier::Allocator>(&modifier->u)
+               : nullptr;
+  if (modifier && !allocatorMod)
+    TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
+
+  mlir::Value allocatorOperand =
+      allocatorMod
+          ? fir::getBase(converter.genExprValue(
+                *Fortran::semantics::GetExpr(allocatorMod->v), stmtCtx))
+          : firOpBuilder.createIntegerConstant(
+                currentLocation, firOpBuilder.getI32Type(), 1);
+  allocatorOperands.append(objects.v.size(), allocatorOperand);
+  genObjectList(objects, converter, allocateOperands);
+}
+
+/// Translate the kind of a PROC_BIND clause into the MLIR OpenMP
+/// `ClauseProcBindKindAttr`.
+static mlir::omp::ClauseProcBindKindAttr genProcBindKindAttr(
+    fir::FirOpBuilder &firOpBuilder,
+    const Fortran::parser::OmpClause::ProcBind *procBindClause) {
+  using ProcBindType = Fortran::parser::OmpProcBindClause::Type;
+  // Return from each case instead of assigning a local. Previously the local
+  // stayed uninitialized if the switch fell through, which GCC flags with
+  // -Wmaybe-uninitialized. The switch has no `default`, so -Wswitch coverage
+  // is kept, and the impossible fall-through is now explicit.
+  auto translate = [](ProcBindType type) {
+    switch (type) {
+    case ProcBindType::Master:
+      return mlir::omp::ClauseProcBindKind::Master;
+    case ProcBindType::Close:
+      return mlir::omp::ClauseProcBindKind::Close;
+    case ProcBindType::Spread:
+      return mlir::omp::ClauseProcBindKind::Spread;
+    case ProcBindType::Primary:
+      return mlir::omp::ClauseProcBindKind::Primary;
+    }
+    llvm_unreachable("unknown parser proc_bind kind");
+  };
+  return mlir::omp::ClauseProcBindKindAttr::get(
+      firOpBuilder.getContext(), translate(procBindClause->v.v));
+}
+
+/// Translate the dependence type of a DEPEND(in|out|inout: ...) clause into
+/// the MLIR OpenMP task-dependence attribute. Any other dependence type is
+/// treated as unreachable.
+static mlir::omp::ClauseTaskDependAttr
+genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
+                  const Fortran::parser::OmpClause::Depend *dependClause) {
+  using DepType = Fortran::parser::OmpDependenceType::Type;
+  const auto &inOut =
+      std::get<Fortran::parser::OmpDependClause::InOut>(dependClause->v.u);
+  const DepType depType =
+      std::get<Fortran::parser::OmpDependenceType>(inOut.t).v;
+
+  mlir::omp::ClauseTaskDepend kind;
+  if (depType == DepType::In)
+    kind = mlir::omp::ClauseTaskDepend::taskdependin;
+  else if (depType == DepType::Out)
+    kind = mlir::omp::ClauseTaskDepend::taskdependout;
+  else if (depType == DepType::Inout)
+    kind = mlir::omp::ClauseTaskDepend::taskdependinout;
+  else
+    llvm_unreachable("unknown parser task dependence type");
+  return mlir::omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(), kind);
+}
+
+/// Lower the condition of an IF clause to an `i1` value at `clauseLocation`.
+/// Returns a null value when the clause has a directive-name modifier that
+/// names a directive other than `directiveName`.
+static mlir::Value getIfClauseOperand(
+    Fortran::lower::AbstractConverter &converter,
+    const Fortran::parser::OmpClause::If *ifClause,
+    Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
+    mlir::Location clauseLocation) {
+  const auto &nameModifier = std::get<
+      std::optional<Fortran::parser::OmpIfClause::DirectiveNameModifier>>(
+      ifClause->v.t);
+  const bool appliesHere = !nameModifier || *nameModifier == directiveName;
+  if (!appliesHere)
+    return nullptr;
+
+  Fortran::lower::StatementContext stmtCtx;
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  const auto &condition =
+      std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
+  mlir::Value conditionValue = fir::getBase(converter.genExprValue(
+      *Fortran::semantics::GetExpr(condition), stmtCtx));
+  return builder.createConvert(clauseLocation, builder.getI1Type(),
+                               conditionValue);
+}
+
+/// Lower the object list of a use-device clause. The objects' values are
+/// appended to `operands`. For each newly appended value, its type is
+/// checked with `checkMapType` and recorded in `useDeviceTypes`, and its
+/// location in `useDeviceLocs`. Each object's symbol is appended to
+/// `useDeviceSymbols`.
+static void
+addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
+                   const Fortran::parser::OmpObjectList &useDeviceClause,
+                   llvm::SmallVectorImpl<mlir::Value> &operands,
+                   llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+                   llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+                   llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+                       &useDeviceSymbols) {
+  // `operands` may already hold values from earlier clauses. Only the values
+  // appended by this call must be checked and recorded. Otherwise the types
+  // and locations of earlier operands would be pushed a second time, and the
+  // parallel vectors would no longer line up with `operands`.
+  const std::size_t firstNewOperand = operands.size();
+  genObjectList(useDeviceClause, converter, operands);
+  for (std::size_t i = firstNewOperand, e = operands.size(); i != e; ++i) {
+    mlir::Value operand = operands[i];
+    checkMapType(operand.getLoc(), operand.getType());
+    useDeviceTypes.push_back(operand.getType());
+    useDeviceLocs.push_back(operand.getLoc());
+  }
+  for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v)
+    useDeviceSymbols.push_back(getOmpObjectSymbol(ompObject));
+}
+
+//===----------------------------------------------------------------------===//
+// ClauseProcessor unique clauses
+//===----------------------------------------------------------------------===//
+
+/// Gather the lower/upper bounds, steps and induction-variable symbols of the
+/// loop nest associated with the construct being lowered. With a COLLAPSE(n)
+/// clause, `n` nested DO loops are processed; otherwise only the outermost
+/// one is. `loopVarTypeSize` receives the largest `size()` among the
+/// (ultimate) induction-variable symbols collected.
+///
+/// Returns true only when a COLLAPSE clause is present. The output vectors
+/// are filled in either case, because the loop below runs at least once.
+bool ClauseProcessor::processCollapse(
+    mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
+    llvm::SmallVectorImpl<mlir::Value> &lowerBound,
+    llvm::SmallVectorImpl<mlir::Value> &upperBound,
+    llvm::SmallVectorImpl<mlir::Value> &step,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
+    std::size_t &loopVarTypeSize) const {
+  bool found = false;
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  // Collect the loops to collapse.
+  // NOTE(review): assumes the first nested evaluation is a DoConstruct; the
+  // result of getIf<> is dereferenced without a null check.
+  Fortran::lower::pft::Evaluation *doConstructEval =
+      &eval.getFirstNestedEvaluation();
+  if (doConstructEval->getIf<Fortran::parser::DoConstruct>()
+          ->IsDoConcurrent()) {
+    TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
+  }
+
+  // Without COLLAPSE, a single loop level is processed. The collapse count
+  // is read as a constant; `.value()` presumes it was folded by semantics.
+  std::int64_t collapseValue = 1l;
+  if (auto *collapseClause = findUniqueClause<ClauseTy::Collapse>()) {
+    const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
+    collapseValue = Fortran::evaluate::ToInt64(*expr).value();
+    found = true;
+  }
+
+  loopVarTypeSize = 0;
+  do {
+    // The first nested evaluation of a DO construct is expected to be its
+    // (non-label) DO statement.
+    Fortran::lower::pft::Evaluation *doLoop =
+        &doConstructEval->getFirstNestedEvaluation();
+    auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
+    assert(doStmt && "Expected do loop to be in the nested evaluation");
+    const auto &loopControl =
+        std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
+    // NOTE(review): `loopControl` is dereferenced unchecked, so a DO without
+    // loop control is assumed to have been rejected before lowering.
+    const Fortran::parser::LoopControl::Bounds *bounds =
+        std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+    assert(bounds && "Expected bounds for worksharing do loop");
+    Fortran::lower::StatementContext stmtCtx;
+    lowerBound.push_back(fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
+    upperBound.push_back(fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
+    if (bounds->step) {
+      step.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
+    } else { // If `step` is not present, assume it as `1`.
+      step.push_back(firOpBuilder.createIntegerConstant(
+          currentLocation, firOpBuilder.getIntegerType(32), 1));
+    }
+    iv.push_back(bounds->name.thing.symbol);
+    loopVarTypeSize = std::max(loopVarTypeSize,
+                               bounds->name.thing.symbol->GetUltimate().size());
+    collapseValue--;
+    // Descend one level: the evaluation right after the DO statement,
+    // presumably the next nested DO construct.
+    doConstructEval =
+        &*std::next(doConstructEval->getNestedEvaluations().begin());
+  } while (collapseValue > 0);
+
+  return found;
+}
+
+/// Handle the DEFAULT clause. None of its values needs lowering yet (see the
+/// per-case notes). Returns whether the clause is present.
+bool ClauseProcessor::processDefault() const {
+  const auto *defaultClause = findUniqueClause<ClauseTy::Default>();
+  if (!defaultClause)
+    return false;
+
+  using DefaultType = Fortran::parser::OmpDefaultClause::Type;
+  switch (defaultClause->v.v) {
+  case DefaultType::Shared:
+  case DefaultType::None:
+    // SHARED is already the default behavior in the IR, and NONE only
+    // matters for semantic checks, so neither needs handling here.
+    break;
+  case DefaultType::Private:
+    // TODO Support default(private)
+    break;
+  case DefaultType::Firstprivate:
+    // TODO Support default(firstprivate)
+    break;
+  }
+  return true;
+}
+
+/// Lower the DEVICE clause expression into `result`, when it has an analyzed
+/// expression. The ANCESTOR device modifier is not supported yet. Returns
+/// whether the clause is present.
+bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
+                                    mlir::Value &result) const {
+  const Fortran::parser::CharBlock *source = nullptr;
+  const auto *deviceClause = findUniqueClause<ClauseTy::Device>(&source);
+  if (!deviceClause)
+    return false;
+
+  mlir::Location clauseLocation = converter.genLocation(*source);
+  const auto &deviceModifier = std::get<
+      std::optional<Fortran::parser::OmpDeviceClause::DeviceModifier>>(
+      deviceClause->v.t);
+  if (deviceModifier &&
+      *deviceModifier ==
+          Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor)
+    TODO(clauseLocation, "OMPD_target Device Modifier Ancestor");
+
+  const auto &deviceIntExpr =
+      std::get<Fortran::parser::ScalarIntExpr>(deviceClause->v.t);
+  if (const auto *deviceExpr = Fortran::semantics::GetExpr(deviceIntExpr))
+    result = fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
+  return true;
+}
+
+/// Translate a DEVICE_TYPE clause (declare target) into `result`. Returns
+/// whether the clause is present.
+bool ClauseProcessor::processDeviceType(
+    mlir::omp::DeclareTargetDeviceType &result) const {
+  const auto *deviceTypeClause = findUniqueClause<ClauseTy::DeviceType>();
+  if (!deviceTypeClause)
+    return false;
+
+  // Case: declare target ... device_type(any | host | nohost)
+  using DeviceType = Fortran::parser::OmpDeviceTypeClause::Type;
+  switch (deviceTypeClause->v.v) {
+  case DeviceType::Nohost:
+    result = mlir::omp::DeclareTargetDeviceType::nohost;
+    break;
+  case DeviceType::Host:
+    result = mlir::omp::DeclareTargetDeviceType::host;
+    break;
+  case DeviceType::Any:
+    result = mlir::omp::DeclareTargetDeviceType::any;
+    break;
+  }
+  return true;
+}
+
+/// Lower the FINAL clause condition to an `i1` value in `result`. Returns
+/// whether the clause is present.
+bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
+                                   mlir::Value &result) const {
+  const Fortran::parser::CharBlock *source = nullptr;
+  const auto *finalClause = findUniqueClause<ClauseTy::Final>(&source);
+  if (!finalClause)
+    return false;
+
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  mlir::Location loc = converter.genLocation(*source);
+  const auto *condExpr = Fortran::semantics::GetExpr(finalClause->v);
+  mlir::Value condValue =
+      fir::getBase(converter.genExprValue(*condExpr, stmtCtx));
+  result = builder.createConvert(loc, builder.getI1Type(), condValue);
+  return true;
+}
+
+/// Store the constant HINT value as an i64 attribute in `result`. Returns
+/// whether the clause is present.
+bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const {
+  const auto *hintClause = findUniqueClause<ClauseTy::Hint>();
+  if (!hintClause)
+    return false;
+
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  const auto *hintExpr = Fortran::semantics::GetExpr(hintClause->v);
+  result = builder.getI64IntegerAttr(*Fortran::evaluate::ToInt64(*hintExpr));
+  return true;
+}
+
+/// Record the presence of a MERGEABLE clause in `result` via
+/// `markClauseOccurrence`, and return its result.
+bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const {
+  const bool isPresent = markClauseOccurrence<ClauseTy::Mergeable>(result);
+  return isPresent;
+}
+
+/// Record the presence of a NOWAIT clause in `result` via
+/// `markClauseOccurrence`, and return its result.
+bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
+  const bool isPresent = markClauseOccurrence<ClauseTy::Nowait>(result);
+  return isPresent;
+}
+
+/// Lower the NUM_TEAMS expression into `result`. Returns whether the clause
+/// is present.
+bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
+                                      mlir::Value &result) const {
+  // TODO Get lower and upper bounds for num_teams when parser is updated to
+  // accept both.
+  const auto *numTeamsClause = findUniqueClause<ClauseTy::NumTeams>();
+  if (!numTeamsClause)
+    return false;
+
+  const auto *numTeamsExpr = Fortran::semantics::GetExpr(numTeamsClause->v);
+  result = fir::getBase(converter.genExprValue(*numTeamsExpr, stmtCtx));
+  return true;
+}
+
+/// Lower the NUM_THREADS expression into `result`. Returns whether the
+/// clause is present.
+bool ClauseProcessor::processNumThreads(
+    Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+  const auto *numThreadsClause = findUniqueClause<ClauseTy::NumThreads>();
+  if (!numThreadsClause)
+    return false;
+
+  // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
+  const auto *numThreadsExpr =
+      Fortran::semantics::GetExpr(numThreadsClause->v);
+  result = fir::getBase(converter.genExprValue(*numThreadsExpr, stmtCtx));
+  return true;
+}
+
+/// Store the ORDERED clause argument as an i64 attribute in `result`. The
+/// value is 0 when the clause has no argument. Returns whether the clause is
+/// present.
+bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const {
+  const auto *orderedClause = findUniqueClause<ClauseTy::Ordered>();
+  if (!orderedClause)
+    return false;
+
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  std::int64_t orderedValue = 0;
+  if (orderedClause->v.has_value())
+    orderedValue = *Fortran::evaluate::ToInt64(
+        *Fortran::semantics::GetExpr(orderedClause->v));
+  result = builder.getI64IntegerAttr(orderedValue);
+  return true;
+}
+
+/// Lower the PRIORITY expression into `result`. Returns whether the clause
+/// is present.
+bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx,
+                                      mlir::Value &result) const {
+  const auto *priorityClause = findUniqueClause<ClauseTy::Priority>();
+  if (!priorityClause)
+    return false;
+
+  const auto *priorityExpr = Fortran::semantics::GetExpr(priorityClause->v);
+  result = fir::getBase(converter.genExprValue(*priorityExpr, stmtCtx));
+  return true;
+}
+
+/// Translate a PROC_BIND clause into `result`. Returns whether the clause is
+/// present.
+bool ClauseProcessor::processProcBind(
+    mlir::omp::ClauseProcBindKindAttr &result) const {
+  const auto *procBindClause = findUniqueClause<ClauseTy::ProcBind>();
+  if (!procBindClause)
+    return false;
+
+  result = genProcBindKindAttr(converter.getFirOpBuilder(), procBindClause);
+  return true;
+}
+
+/// Store the constant SAFELEN value as an i64 attribute in `result`. Returns
+/// whether the clause is present.
+bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const {
+  const auto *safelenClause = findUniqueClause<ClauseTy::Safelen>();
+  if (!safelenClause)
+    return false;
+
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  const std::optional<std::int64_t> safelenVal = Fortran::evaluate::ToInt64(
+      *Fortran::semantics::GetExpr(safelenClause->v));
+  result = builder.getI64IntegerAttr(*safelenVal);
+  return true;
+}
+
+/// Lower a SCHEDULE clause. The schedule kind is stored in `valAttr`. A
+/// non-SIMD modifier (monotonic/nonmonotonic) is stored in `modifierAttr`,
+/// and a SIMD modifier sets `simdModifierAttr`. Outputs whose information is
+/// absent are left untouched. Returns whether the clause is present.
+bool ClauseProcessor::processSchedule(
+    mlir::omp::ClauseScheduleKindAttr &valAttr,
+    mlir::omp::ScheduleModifierAttr &modifierAttr,
+    mlir::UnitAttr &simdModifierAttr) const {
+  const auto *scheduleClause = findUniqueClause<ClauseTy::Schedule>();
+  if (!scheduleClause)
+    return false;
+
+  using ScheduleType = Fortran::parser::OmpScheduleClause::ScheduleType;
+  // Return from every case instead of assigning a local. This prevents the
+  // kind from being read uninitialized if the switch falls through. The
+  // switch has no `default`, so -Wswitch still checks exhaustiveness.
+  auto translateKind = [](ScheduleType type) {
+    switch (type) {
+    case ScheduleType::Static:
+      return mlir::omp::ClauseScheduleKind::Static;
+    case ScheduleType::Dynamic:
+      return mlir::omp::ClauseScheduleKind::Dynamic;
+    case ScheduleType::Guided:
+      return mlir::omp::ClauseScheduleKind::Guided;
+    case ScheduleType::Auto:
+      return mlir::omp::ClauseScheduleKind::Auto;
+    case ScheduleType::Runtime:
+      return mlir::omp::ClauseScheduleKind::Runtime;
+    }
+    llvm_unreachable("unknown parser schedule kind");
+  };
+
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::MLIRContext *context = firOpBuilder.getContext();
+  const Fortran::parser::OmpScheduleClause &scheduleType = scheduleClause->v;
+  mlir::omp::ClauseScheduleKind scheduleKind =
+      translateKind(std::get<ScheduleType>(scheduleType.t));
+
+  mlir::omp::ScheduleModifier scheduleModifier =
+      getScheduleModifier(scheduleType);
+  if (scheduleModifier != mlir::omp::ScheduleModifier::none)
+    modifierAttr =
+        mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
+
+  if (getSimdModifier(scheduleType) != mlir::omp::ScheduleModifier::none)
+    simdModifierAttr = firOpBuilder.getUnitAttr();
+
+  valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
+  return true;
+}
+
+/// Lower the optional chunk-size expression of a SCHEDULE clause into
+/// `result`. `result` is written only when a chunk with an analyzed
+/// expression is present. Returns whether the clause is present.
+bool ClauseProcessor::processScheduleChunk(
+    Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+  const auto *scheduleClause = findUniqueClause<ClauseTy::Schedule>();
+  if (!scheduleClause)
+    return false;
+
+  const auto &chunk = std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
+      scheduleClause->v.t);
+  const auto *chunkExpr =
+      chunk ? Fortran::semantics::GetExpr(*chunk) : nullptr;
+  if (chunkExpr)
+    result = fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
+  return true;
+}
+
+/// Store the constant SIMDLEN value as an i64 attribute in `result`. Returns
+/// whether the clause is present.
+bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const {
+  const auto *simdlenClause = findUniqueClause<ClauseTy::Simdlen>();
+  if (!simdlenClause)
+    return false;
+
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  const std::optional<std::int64_t> simdlenVal = Fortran::evaluate::ToInt64(
+      *Fortran::semantics::GetExpr(simdlenClause->v));
+  result = builder.getI64IntegerAttr(*simdlenVal);
+  return true;
+}
+
+/// Lower the THREAD_LIMIT expression into `result`. Returns whether the
+/// clause is present.
+bool ClauseProcessor::processThreadLimit(
+    Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+  const auto *threadLimitClause = findUniqueClause<ClauseTy::ThreadLimit>();
+  if (!threadLimitClause)
+    return false;
+
+  const auto *limitExpr = Fortran::semantics::GetExpr(threadLimitClause->v);
+  result = fir::getBase(converter.genExprValue(*limitExpr, stmtCtx));
+  return true;
+}
+
+/// Record the presence of an UNTIED clause in `result` via
+/// `markClauseOccurrence`, and return its result.
+bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
+  const bool isPresent = markClauseOccurrence<ClauseTy::Untied>(result);
+  return isPresent;
+}
+
+//===----------------------------------------------------------------------===//
+// ClauseProcessor repeatable clauses
+//===----------------------------------------------------------------------===//
+
+/// Lower every ALLOCATE clause through `genAllocateClause`, accumulating into
+/// `allocatorOperands` and `allocateOperands`.
+bool ClauseProcessor::processAllocate(
+    llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
+    llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
+  auto lowerOne = [&](const ClauseTy::Allocate *clause,
+                      const Fortran::parser::CharBlock &) {
+    genAllocateClause(converter, clause->v, allocatorOperands,
+                      allocateOperands);
+  };
+  return findRepeatableClause<ClauseTy::Allocate>(lowerOne);
+}
+
+/// Lower COPYIN clauses. For every listed host-associated variable (or every
+/// member of a listed common block) that `isPresentShallowLookup` finds,
+/// `copyHostAssociateVar` is called. While this happens, the builder's
+/// insertion point is moved to the start of the alloca block, and it is
+/// restored afterwards. If any COPYIN clause was found, a barrier is emitted
+/// before the restore. Pointer/allocatable variables are not supported yet.
+bool ClauseProcessor::processCopyin() const {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
+  firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
+  auto checkAndCopyHostAssociateVar =
+      [&](Fortran::semantics::Symbol *sym,
+          mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) {
+        assert(sym->has<Fortran::semantics::HostAssocDetails>() &&
+               "No host-association found");
+        if (converter.isPresentShallowLookup(*sym))
+          converter.copyHostAssociateVar(*sym, copyAssignIP);
+      };
+  bool hasCopyin = findRepeatableClause<ClauseTy::Copyin>(
+      [&](const ClauseTy::Copyin *copyinClause,
+          const Fortran::parser::CharBlock &) {
+        const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v;
+        for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
+          Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+          if (const auto *commonDetails =
+                  sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
+            for (const auto &mem : commonDetails->objects())
+              checkAndCopyHostAssociateVar(&*mem, &insPt);
+            // Continue with the next list item. A `break` here would
+            // silently skip every object listed after the common block,
+            // e.g. `x` in `copyin(/blk/, x)`.
+            continue;
+          }
+          if (Fortran::semantics::IsAllocatableOrObjectPointer(
+                  &sym->GetUltimate()))
+            TODO(converter.getCurrentLocation(),
+                 "pointer or allocatable variables in Copyin clause");
+          // The host-association assertion is performed inside the helper.
+          checkAndCopyHostAssociateVar(sym);
+        }
+      });
+
+  // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to
+  // the execution of the associated structured block. Emit implicit barrier to
+  // synchronize threads and avoid data races on propagation master's thread
+  // values of threadprivate variables to local instances of that variables of
+  // all other implicit threads.
+  if (hasCopyin)
+    firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
+  firOpBuilder.restoreInsertionPoint(insPt);
+  return hasCopyin;
+}
+
+/// Lower DEPEND clauses. For every listed designator, the clause's dependence
+/// kind attribute is appended to `dependTypeOperands` and the variable's
+/// address to `dependOperands`. Only plain variable names are supported.
+/// Array elements/sections, substrings and other data references (such as
+/// structure components) hit a TODO.
+bool ClauseProcessor::processDepend(
+    llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
+    llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  return findRepeatableClause<ClauseTy::Depend>(
+      [&](const ClauseTy::Depend *dependClause,
+          const Fortran::parser::CharBlock &) {
+        const std::list<Fortran::parser::Designator> &depVal =
+            std::get<std::list<Fortran::parser::Designator>>(
+                std::get<Fortran::parser::OmpDependClause::InOut>(
+                    dependClause->v.u)
+                    .t);
+        mlir::omp::ClauseTaskDependAttr dependTypeOperand =
+            genDependKindAttr(firOpBuilder, dependClause);
+        dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
+                                  dependTypeOperand);
+        for (const Fortran::parser::Designator &ompObject : depVal) {
+          Fortran::semantics::Symbol *sym = nullptr;
+          std::visit(
+              Fortran::common::visitors{
+                  [&](const Fortran::parser::DataRef &designator) {
+                    if (const Fortran::parser::Name *name =
+                            std::get_if<Fortran::parser::Name>(
+                                &designator.u)) {
+                      sym = name->symbol;
+                    } else if (std::get_if<Fortran::common::Indirection<
+                                   Fortran::parser::ArrayElement>>(
+                                   &designator.u)) {
+                      TODO(converter.getCurrentLocation(),
+                           "array sections not supported for task depend");
+                    }
+                  },
+                  [&](const Fortran::parser::Substring &designator) {
+                    TODO(converter.getCurrentLocation(),
+                         "substring not supported for task depend");
+                  }},
+              ompObject.u);
+          // Any other kind of data reference (e.g. a structure component)
+          // leaves `sym` unset. Report it instead of dereferencing null.
+          if (!sym)
+            TODO(converter.getCurrentLocation(),
+                 "only variables are supported for task depend");
+          const mlir::Value variable = converter.getSymbolAddress(*sym);
+          dependOperands.push_back(variable);
+        }
+      });
+}
+
+/// Lower the IF clause that applies to `directiveName` into `result`.
+/// Returns whether an applicable IF clause was found.
+bool ClauseProcessor::processIf(
+    Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
+    mlir::Value &result) const {
+  bool found = false;
+  auto lowerIf = [&](const ClauseTy::If *ifClause,
+                     const Fortran::parser::CharBlock &source) {
+    mlir::Location clauseLocation = converter.genLocation(source);
+    mlir::Value operand = getIfClauseOperand(converter, ifClause,
+                                             directiveName, clauseLocation);
+    // At most one IF clause is expected to apply to the given directive. If
+    // several do, the last applicable one wins.
+    if (!operand)
+      return;
+    result = operand;
+    found = true;
+  };
+  findRepeatableClause<ClauseTy::If>(lowerIf);
+  return found;
+}
+
+/// Pass the objects of every LINK clause (`declare target link(...)`) to
+/// `gatherFuncAndVarSyms` with the `link` capture clause, accumulating into
+/// `result`.
+bool ClauseProcessor::processLink(
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
+  auto collectLinked = [&](const ClauseTy::Link *linkClause,
+                           const Fortran::parser::CharBlock &) {
+    // Case: declare target link(var1, var2)...
+    gatherFuncAndVarSyms(linkClause->v,
+                         mlir::omp::DeclareTargetCaptureClause::link, result);
+  };
+  return findRepeatableClause<ClauseTy::Link>(collectLinked);
+}
+
+/// Create an `omp.map_info` operation for `baseAddr`. A boxed base address
+/// is first unwrapped with `fir.box_addr`, and the result type then becomes
+/// the unwrapped address type. The map type is stored as an unsigned 64-bit
+/// integer attribute.
+/// NOTE(review): `isVal` is not used by this body.
+mlir::omp::MapInfoOp
+createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
+                mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
+                mlir::SmallVector<mlir::Value> bounds,
+                mlir::SmallVector<mlir::Value> members, uint64_t mapType,
+                mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
+                bool isVal) {
+  if (baseAddr.getType().isa<fir::BaseBoxType>()) {
+    baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
+    retTy = baseAddr.getType();
+  }
+
+  auto pointerLikeTy = llvm::cast<mlir::omp::PointerLikeType>(retTy);
+  mlir::TypeAttr varType = mlir::TypeAttr::get(pointerLikeTy.getElementType());
+  mlir::IntegerAttr mapTypeAttr =
+      builder.getIntegerAttr(builder.getIntegerType(64, false), mapType);
+  auto captureAttr =
+      builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType);
+  return builder.create<mlir::omp::MapInfoOp>(
+      loc, retTy, baseAddr, varType, varPtrPtr, members, bounds, mapTypeAttr,
+      captureAttr, builder.getStringAttr(name));
+}
+
+/// Lower every MAP clause into `omp.map_info` operations appended to
+/// `mapOperands`. When the optional output vectors are non-null, they
+/// receive, per mapped object, the type and location of the mapped address
+/// (`mapSymTypes`, `mapSymLocs`) and the object's symbol (`mapSymbols`).
+/// Returns the result of `findRepeatableClause` (presumably whether any MAP
+/// clause was found).
+/// NOTE(review): `currentLocation` and `directive` are not used in this body.
+bool ClauseProcessor::processMap(
+    mlir::Location currentLocation, const llvm::omp::Directive &directive,
+    Fortran::lower::StatementContext &stmtCtx,
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands,
+    llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
+    llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
+    const {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  return findRepeatableClause<ClauseTy::Map>(
+      [&](const ClauseTy::Map *mapClause,
+          const Fortran::parser::CharBlock &source) {
+        mlir::Location clauseLocation = converter.genLocation(source);
+        const auto &oMapType =
+            std::get<std::optional<Fortran::parser::OmpMapType>>(
+                mapClause->v.t);
+        llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+        // If the map type is specified, then process it else Tofrom is the
+        // default.
+        if (oMapType) {
+          const Fortran::parser::OmpMapType::Type &mapType =
+              std::get<Fortran::parser::OmpMapType::Type>(oMapType->t);
+          switch (mapType) {
+          case Fortran::parser::OmpMapType::Type::To:
+            mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+            break;
+          case Fortran::parser::OmpMapType::Type::From:
+            mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+            break;
+          case Fortran::parser::OmpMapType::Type::Tofrom:
+            mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+                           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+            break;
+          case Fortran::parser::OmpMapType::Type::Alloc:
+          case Fortran::parser::OmpMapType::Type::Release:
+            // alloc and release is the default map_type for the Target Data
+            // Ops, i.e. if no bits for map_type is supplied then alloc/release
+            // is implicitly assumed based on the target directive. Default
+            // value for Target Data and Enter Data is alloc and for Exit Data
+            // it is release.
+            break;
+          case Fortran::parser::OmpMapType::Type::Delete:
+            // Last case: no `break` needed. Add one if cases are appended.
+            mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
+          }
+
+          // The ALWAYS map-type modifier adds OMP_MAP_ALWAYS on top.
+          if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
+                  oMapType->t))
+            mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+        } else {
+          mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+                         llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+        }
+
+        for (const Fortran::parser::OmpObject &ompObject :
+             std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) {
+          llvm::SmallVector<mlir::Value> bounds;
+          std::stringstream asFortran;
+
+          // Compute the object's address and bounds. `asFortran` receives a
+          // textual name that is used for the map operation's name.
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::gatherDataOperandAddrAndBounds<
+                  Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
+                  mlir::omp::DataBoundsType>(
+                  converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
+                  clauseLocation, asFortran, bounds, treatIndexAsSection);
+
+          // When the symbol's address has a descriptor type, map that
+          // address instead of `info.addr`.
+          auto origSymbol =
+              converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
+          mlir::Value symAddr = info.addr;
+          if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
+            symAddr = origSymbol;
+
+          // Explicit map captures are captured ByRef by default,
+          // optimisation passes may alter this to ByCopy or other capture
+          // types to optimise
+          mlir::Value mapOp = createMapInfoOp(
+              firOpBuilder, clauseLocation, symAddr, mlir::Value{},
+              asFortran.str(), bounds, {},
+              static_cast<
+                  std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+                  mapTypeBits),
+              mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
+
+          mapOperands.push_back(mapOp);
+          if (mapSymTypes)
+            mapSymTypes->push_back(symAddr.getType());
+          if (mapSymLocs)
+            mapSymLocs->push_back(symAddr.getLoc());
+
+          if (mapSymbols)
+            mapSymbols->push_back(getOmpObjectSymbol(ompObject));
+        }
+      });
+}
+
+/// Lower every REDUCTION clause in the list by handing its contents to a
+/// `ReductionProcessor::addReductionDecl`, forwarding the output vectors
+/// (`reductionSymbols` may be null). Returns `true` if at least one
+/// REDUCTION clause was found.
+bool ClauseProcessor::processReduction(
+    mlir::Location currentLocation,
+    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
+    const {
+  auto lowerReduction = [&](const ClauseTy::Reduction *clause,
+                            const Fortran::parser::CharBlock &) {
+    ReductionProcessor processor;
+    processor.addReductionDecl(currentLocation, converter, clause->v,
+                               reductionVars, reductionDeclSymbols,
+                               reductionSymbols);
+  };
+  return findRepeatableClause<ClauseTy::Reduction>(lowerReduction);
+}
+
+/// REDUCTION on SECTIONS is not lowered yet: any such clause is reported
+/// as a TODO at `currentLocation`.
+bool ClauseProcessor::processSectionsReduction(
+    mlir::Location currentLocation) const {
+  auto reportUnsupported = [&](const ClauseTy::Reduction *,
+                               const Fortran::parser::CharBlock &) {
+    TODO(currentLocation, "OMPC_Reduction");
+  };
+  return findRepeatableClause<ClauseTy::Reduction>(reportUnsupported);
+}
+
+/// Gather the functions and variables named in every TO clause (DECLARE
+/// TARGET form) into `result`, passing the `to` capture clause kind.
+bool ClauseProcessor::processTo(
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
+  auto gatherToSyms = [&](const ClauseTy::To *clause,
+                          const Fortran::parser::CharBlock &) {
+    // Case: declare target to(func, var1, var2)...
+    gatherFuncAndVarSyms(clause->v, mlir::omp::DeclareTargetCaptureClause::to,
+                         result);
+  };
+  return findRepeatableClause<ClauseTy::To>(gatherToSyms);
+}
+
+/// Gather the functions and variables named in every ENTER clause (DECLARE
+/// TARGET form) into `result`, passing the `enter` capture clause kind.
+bool ClauseProcessor::processEnter(
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
+  auto gatherEnterSyms = [&](const ClauseTy::Enter *clause,
+                             const Fortran::parser::CharBlock &) {
+    // Case: declare target enter(func, var1, var2)...
+    gatherFuncAndVarSyms(clause->v,
+                         mlir::omp::DeclareTargetCaptureClause::enter, result);
+  };
+  return findRepeatableClause<ClauseTy::Enter>(gatherEnterSyms);
+}
+
+/// Forward every USE_DEVICE_ADDR clause to `addUseDeviceClause` together
+/// with `operands` and the type/location/symbol output vectors.
+bool ClauseProcessor::processUseDeviceAddr(
+    llvm::SmallVectorImpl<mlir::Value> &operands,
+    llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+    llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
+    const {
+  auto addDeviceAddr = [&](const ClauseTy::UseDeviceAddr *clause,
+                           const Fortran::parser::CharBlock &) {
+    addUseDeviceClause(converter, clause->v, operands, useDeviceTypes,
+                       useDeviceLocs, useDeviceSymbols);
+  };
+  return findRepeatableClause<ClauseTy::UseDeviceAddr>(addDeviceAddr);
+}
+
+/// Forward every USE_DEVICE_PTR clause to `addUseDeviceClause` together
+/// with `operands` and the type/location/symbol output vectors.
+bool ClauseProcessor::processUseDevicePtr(
+    llvm::SmallVectorImpl<mlir::Value> &operands,
+    llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+    llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
+    const {
+  auto addDevicePtr = [&](const ClauseTy::UseDevicePtr *clause,
+                          const Fortran::parser::CharBlock &) {
+    addUseDeviceClause(converter, clause->v, operands, useDeviceTypes,
+                       useDeviceLocs, useDeviceSymbols);
+  };
+  return findRepeatableClause<ClauseTy::UseDevicePtr>(addDevicePtr);
+}
+} // namespace omp
+} // namespace lower
+} // namespace Fortran
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
new file mode 100644
index 00000000000000..312255112605e3
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -0,0 +1,305 @@
+//===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+#ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H
+#define FORTRAN_LOWER_CLAUASEPROCESSOR_H
+
+#include "DirectivesCommon.h"
+#include "ReductionProcessor.h"
+#include "Utils.h"
+#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/Bridge.h"
+#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Parser/dump-parse-tree.h"
+#include "flang/Parser/parse-tree.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+
+namespace fir {
+class FirOpBuilder;
+} // namespace fir
+
+namespace Fortran {
+namespace lower {
+namespace omp {
+
+/// Class that handles the processing of OpenMP clauses.
+///
+/// Its `process<ClauseName>()` methods perform MLIR code generation for their
+/// corresponding clause if it is present in the clause list. Otherwise, they
+/// will return `false` to signal that the clause was not found.
+///
+/// The intended use of this class is to move clause processing outside of
+/// construct processing, since the same clauses can appear attached to
+/// different constructs and constructs can be combined, so that code
+/// duplication is minimized.
+///
+/// Each construct-lowering function only calls the `process<ClauseName>()`
+/// methods that relate to clauses that can impact the lowering of that
+/// construct.
+class ClauseProcessor {
+ // Parse-tree clause type; `clauses.v` is a list of these (see ClauseIterator).
+ using ClauseTy = Fortran::parser::OmpClause;
+
+public:
+ /// `converter`, `semaCtx` and `clauses` are stored by reference, so they
+ /// must outlive this processor.
+ ClauseProcessor(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const Fortran::parser::OmpClauseList &clauses)
+ : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
+
+ // 'Unique' clauses: They can appear at most once in the clause list.
+ // As described on the class, each returns `false` if its clause is absent.
+ bool
+ processCollapse(mlir::Location currentLocation,
+ Fortran::lower::pft::Evaluation &eval,
+ llvm::SmallVectorImpl<mlir::Value> &lowerBound,
+ llvm::SmallVectorImpl<mlir::Value> &upperBound,
+ llvm::SmallVectorImpl<mlir::Value> &step,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
+ std::size_t &loopVarTypeSize) const;
+ bool processDefault() const;
+ bool processDevice(Fortran::lower::StatementContext &stmtCtx,
+ mlir::Value &result) const;
+ bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
+ bool processFinal(Fortran::lower::StatementContext &stmtCtx,
+ mlir::Value &result) const;
+ bool processHint(mlir::IntegerAttr &result) const;
+ bool processMergeable(mlir::UnitAttr &result) const;
+ bool processNowait(mlir::UnitAttr &result) const;
+ bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
+ mlir::Value &result) const;
+ bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
+ mlir::Value &result) const;
+ bool processOrdered(mlir::IntegerAttr &result) const;
+ bool processPriority(Fortran::lower::StatementContext &stmtCtx,
+ mlir::Value &result) const;
+ bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const;
+ bool processSafelen(mlir::IntegerAttr &result) const;
+ bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr,
+ mlir::omp::ScheduleModifierAttr &modifierAttr,
+ mlir::UnitAttr &simdModifierAttr) const;
+ bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx,
+ mlir::Value &result) const;
+ bool processSimdlen(mlir::IntegerAttr &result) const;
+ bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx,
+ mlir::Value &result) const;
+ bool processUntied(mlir::UnitAttr &result) const;
+
+ // 'Repeatable' clauses: They can appear multiple times in the clause list.
+ // These are implemented on top of `findRepeatableClause`, visiting every
+ // occurrence in list order.
+ bool
+ processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
+ llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const;
+ bool processCopyin() const;
+ bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
+ llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
+ bool
+ processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
+ bool
+ processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
+ mlir::Value &result) const;
+ bool
+ processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
+
+ // This method is used to process a map clause.
+ // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to
+ // store the original type, location and Fortran symbol for the map operands.
+ // They may be used later on to create the block_arguments for some of the
+ // target directives that require it.
+ bool processMap(mlir::Location currentLocation,
+ const llvm::omp::Directive &directive,
+ Fortran::lower::StatementContext &stmtCtx,
+ llvm::SmallVectorImpl<mlir::Value> &mapOperands,
+ llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
+ llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ *mapSymbols = nullptr) const;
+ bool
+ processReduction(mlir::Location currentLocation,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ *reductionSymbols = nullptr) const;
+ bool processSectionsReduction(mlir::Location currentLocation) const;
+ bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
+ bool
+ processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ &useDeviceSymbols) const;
+ bool
+ processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ &useDeviceSymbols) const;
+
+ // Lower every motion clause of type T into map operands appended to
+ // `mapOperands`. Only ClauseTy::To and ClauseTy::From are accepted (a
+ // static_assert in the definition enforces this).
+ template <typename T>
+ bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx,
+ llvm::SmallVectorImpl<mlir::Value> &mapOperands);
+
+ // Call this method for these clauses that should be supported but are not
+ // implemented yet. It triggers a compilation error if any of the given
+ // clauses is found.
+ template <typename... Ts>
+ void processTODO(mlir::Location currentLocation,
+ llvm::omp::Directive directive) const;
+
+private:
+ // Iterator over the clause list `clauses.v`.
+ using ClauseIterator = std::list<ClauseTy>::const_iterator;
+
+ /// Utility to find a clause within a range in the clause list.
+ template <typename T>
+ static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
+
+ /// Return the first instance of the given clause found in the clause list or
+ /// `nullptr` if not present. If more than one instance is expected, use
+ /// `findRepeatableClause` instead.
+ template <typename T>
+ const T *
+ findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const;
+
+ /// Call `callbackFn` for each occurrence of the given clause. Return `true`
+ /// if at least one instance was found.
+ template <typename T>
+ bool findRepeatableClause(
+ std::function<void(const T *, const Fortran::parser::CharBlock &source)>
+ callbackFn) const;
+
+ /// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
+ template <typename T>
+ bool markClauseOccurrence(mlir::UnitAttr &result) const;
+
+ // Non-owning references supplied at construction.
+ Fortran::lower::AbstractConverter &converter;
+ Fortran::semantics::SemanticsContext &semaCtx;
+ const Fortran::parser::OmpClauseList &clauses;
+};
+
+/// Lower every motion clause of type `T` (`To` or `From`) into map-info
+/// operations created through `createMapInfoOp`, appending each result to
+/// `mapOperands`. Returns `true` if at least one such clause was present.
+template <typename T>
+bool ClauseProcessor::processMotionClauses(
+    Fortran::lower::StatementContext &stmtCtx,
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
+  // Only the two motion clauses make sense here.
+  static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> ||
+                std::is_same_v<T, ClauseProcessor::ClauseTy::From>);
+
+  // TODO Support motion modifiers: present, mapper, iterator.
+  constexpr llvm::omp::OpenMPOffloadMappingFlags motionFlags =
+      std::is_same_v<T, ClauseProcessor::ClauseTy::To>
+          ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
+          : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+
+  auto lowerMotionClause = [&](const T *motionClause,
+                               const Fortran::parser::CharBlock &source) {
+    mlir::Location clauseLocation = converter.genLocation(source);
+    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+    for (const Fortran::parser::OmpObject &object : motionClause->v.v) {
+      llvm::SmallVector<mlir::Value> bounds;
+      std::stringstream asFortran;
+      Fortran::lower::AddrAndBoundsInfo info =
+          Fortran::lower::gatherDataOperandAddrAndBounds<
+              Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
+              mlir::omp::DataBoundsType>(
+              converter, firOpBuilder, semaCtx, stmtCtx, object,
+              clauseLocation, asFortran, bounds, treatIndexAsSection);
+
+      // Use the symbol's own address when its type carries a descriptor;
+      // otherwise use the address gathered above.
+      auto origSymbol =
+          converter.getSymbolAddress(*getOmpObjectSymbol(object));
+      const bool useOrigSymbol =
+          origSymbol && fir::isTypeWithDescriptor(origSymbol.getType());
+      mlir::Value symAddr = useOrigSymbol ? origSymbol : info.addr;
+
+      // Explicit map captures are captured ByRef by default,
+      // optimisation passes may alter this to ByCopy or other capture
+      // types to optimise
+      mlir::Value mapInfo = createMapInfoOp(
+          firOpBuilder, clauseLocation, symAddr, mlir::Value{},
+          asFortran.str(), bounds, {},
+          static_cast<
+              std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+              motionFlags),
+          mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
+      mapOperands.push_back(mapInfo);
+    }
+  };
+  return findRepeatableClause<T>(lowerMotionClause);
+}
+
+/// Report a TODO for the first clause in the list whose type is one of `Ts`.
+/// The message names the clause and `directive`, both upper-cased.
+template <typename... Ts>
+void ClauseProcessor::processTODO(mlir::Location currentLocation,
+                                  llvm::omp::Directive directive) const {
+  auto checkUnhandledClause = [&](const auto *x) {
+    if (x) {
+      TODO(currentLocation,
+           "Unhandled clause " +
+               llvm::StringRef(Fortran::parser::ParseTreeDumper::GetNodeName(*x))
+                   .upper() +
+               " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() +
+               " construct");
+    }
+  };
+
+  // For each clause, probe every requested alternative via a fold expression.
+  for (const ClauseTy &clause : clauses.v)
+    (checkUnhandledClause(std::get_if<Ts>(&clause.u)), ...);
+}
+
+/// Return an iterator to the first clause in [begin, end) whose variant
+/// currently holds a `T`, or `end` if there is none.
+template <typename T>
+ClauseProcessor::ClauseIterator
+ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
+  while (begin != end && std::get_if<T>(&begin->u) == nullptr)
+    ++begin;
+  return begin;
+}
+
+/// Return the first `T` clause in the list, or `nullptr` when absent. When
+/// found and `source` is non-null, `*source` is pointed at the clause's
+/// source range.
+template <typename T>
+const T *ClauseProcessor::findUniqueClause(
+    const Fortran::parser::CharBlock **source) const {
+  const ClauseIterator endIt = clauses.v.end();
+  const ClauseIterator match = findClause<T>(clauses.v.begin(), endIt);
+  if (match == endIt)
+    return nullptr;
+
+  if (source)
+    *source = &match->source;
+  return &std::get<T>(match->u);
+}
+
+/// Invoke `callbackFn` on every `T` clause, in list order, with the clause
+/// and its source range. Returns `true` if the callback ran at least once.
+template <typename T>
+bool ClauseProcessor::findRepeatableClause(
+    std::function<void(const T *, const Fortran::parser::CharBlock &source)>
+        callbackFn) const {
+  bool found = false;
+  const ClauseIterator endIt = clauses.v.end();
+  ClauseIterator current = findClause<T>(clauses.v.begin(), endIt);
+  while (current != endIt) {
+    callbackFn(&std::get<T>(current->u), current->source);
+    found = true;
+    // Resume the search just past the clause that was handled.
+    ++current;
+    current = findClause<T>(current, endIt);
+  }
+  return found;
+}
+
+/// If a `T` clause is present, set `result` to a unit attribute created by
+/// the converter's builder and return `true`; otherwise leave it untouched.
+template <typename T>
+bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
+  if (!findUniqueClause<T>())
+    return false;
+
+  result = converter.getFirOpBuilder().getUnitAttr();
+  return true;
+}
+
+} // namespace omp
+} // namespace lower
+} // namespace Fortran
+
+#endif // FORTRAN_LOWER_CLAUASEPROCESSOR_H
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
new file mode 100644
index 00000000000000..136bda0b582ee3
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -0,0 +1,350 @@
+//===-- DataSharingProcessor.cpp --------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "DataSharingProcessor.h"
+
+#include "Utils.h"
+#include "flang/Lower/PFTBuilder.h"
+#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Semantics/tools.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+
+namespace Fortran {
+namespace lower {
+namespace omp {
+
+// First privatization step (see the comment on processStep1 in the header).
+void DataSharingProcessor::processStep1() {
+ // Gather symbols named in private/firstprivate/lastprivate clauses, then
+ // those implied by a DEFAULT(PRIVATE) or DEFAULT(FIRSTPRIVATE) clause.
+ collectSymbolsForPrivatization();
+ collectDefaultSymbols();
+ // Clone (and, for firstprivate symbols, copy into) the explicitly listed
+ // symbols first; defaultPrivatize then skips anything already privatized.
+ privatize();
+ defaultPrivatize();
+ // Emits an omp.barrier only if some symbol is both firstprivate and
+ // lastprivate (see needBarrier).
+ insertBarrier();
+}
+
+/// Second privatization step, run once the MLIR operation `op` exists: emit
+/// the lastprivate copy-back, then deallocate privatized allocatables. For
+/// loops the deallocations go right after `op`; otherwise they are emitted
+/// at the current insertion point and the builder is left just after a
+/// placeholder `fir::UndefOp` created before them.
+void DataSharingProcessor::processStep2(mlir::Operation *op, bool isLoop) {
+  // `insPt` is a member on purpose: insertLastPrivateCompare may overwrite it
+  // (sections case), so the restore below honours that update.
+  insPt = firOpBuilder.saveInsertionPoint();
+  copyLastPrivatize(op);
+  firOpBuilder.restoreInsertionPoint(insPt);
+
+  if (isLoop) {
+    // push deallocs out of the loop
+    firOpBuilder.setInsertionPointAfter(op);
+    insertDeallocs();
+    return;
+  }
+
+  // insert dummy instruction to mark the insertion position
+  mlir::Value positionMarker = firOpBuilder.create<fir::UndefOp>(
+      op->getLoc(), firOpBuilder.getIndexType());
+  insertDeallocs();
+  firOpBuilder.setInsertionPointAfter(positionMarker.getDefiningOp());
+}
+
+/// Emit clone-deallocation code for every privatized symbol whose ultimate
+/// symbol is ALLOCATABLE.
+void DataSharingProcessor::insertDeallocs() {
+  for (const Fortran::semantics::Symbol *privSym : privatizedSymbols) {
+    if (!Fortran::semantics::IsAllocatable(privSym->GetUltimate()))
+      continue;
+    converter.createHostAssociateVarCloneDealloc(*privSym);
+  }
+}
+
+/// Create the private clone of `sym` unless it is OmpPreDetermined.
+void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym) {
+  // Privatization for symbols which are pre-determined (like loop index
+  // variables) happen separately, for everything else privatize here.
+  if (!sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined)) {
+    [[maybe_unused]] const bool cloned =
+        converter.createHostAssociateVarClone(*sym);
+    assert(cloned && "Privatization failed due to existing binding");
+  }
+}
+
+/// For a firstprivate `sym`, call `copyHostAssociateVar` (presumably copying
+/// the host value into the private clone); other symbols are left alone.
+void DataSharingProcessor::copyFirstPrivateSymbol(
+    const Fortran::semantics::Symbol *sym) {
+  if (!sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate))
+    return;
+  converter.copyHostAssociateVar(*sym);
+}
+
+/// For a lastprivate `sym`, call `copyHostAssociateVar` with the insertion
+/// point prepared by insertLastPrivateCompare (passed as `insertionPoint`).
+/// The parameter is always used, so it is not marked [[maybe_unused]]. It
+/// also no longer shares its name with the `lastPrivIP` member.
+void DataSharingProcessor::copyLastPrivateSymbol(
+    const Fortran::semantics::Symbol *sym,
+    mlir::OpBuilder::InsertPoint *insertionPoint) {
+  if (sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate))
+    converter.copyHostAssociateVar(*sym, insertionPoint);
+}
+
+/// Insert the symbol of every object in `ompObjectList` into `symbolSet`
+/// (a SetVector: insertion order is kept and duplicates are dropped).
+void DataSharingProcessor::collectOmpObjectListSymbol(
+    const Fortran::parser::OmpObjectList &ompObjectList,
+    llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet) {
+  for (const Fortran::parser::OmpObject &object : ompObjectList.v)
+    symbolSet.insert(getOmpObjectSymbol(object));
+}
+
+/// Record the symbols of PRIVATE, FIRSTPRIVATE and LASTPRIVATE clauses in
+/// `privatizedSymbols` and note whether any LASTPRIVATE was seen. COLLAPSE
+/// combined with LASTPRIVATE is not supported yet and is reported as TODO.
+void DataSharingProcessor::collectSymbolsForPrivatization() {
+  bool hasCollapse = false;
+  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+    const auto &alternative = clause.u;
+    if (const auto *privateClause =
+            std::get_if<Fortran::parser::OmpClause::Private>(&alternative)) {
+      collectOmpObjectListSymbol(privateClause->v, privatizedSymbols);
+    } else if (const auto *firstPrivateClause =
+                   std::get_if<Fortran::parser::OmpClause::Firstprivate>(
+                       &alternative)) {
+      collectOmpObjectListSymbol(firstPrivateClause->v, privatizedSymbols);
+    } else if (const auto *lastPrivateClause =
+                   std::get_if<Fortran::parser::OmpClause::Lastprivate>(
+                       &alternative)) {
+      collectOmpObjectListSymbol(lastPrivateClause->v, privatizedSymbols);
+      hasLastPrivateOp = true;
+    } else if (std::get_if<Fortran::parser::OmpClause::Collapse>(
+                   &alternative)) {
+      hasCollapse = true;
+    }
+  }
+
+  if (hasCollapse && hasLastPrivateOp)
+    TODO(converter.getCurrentLocation(), "Collapse clause with lastprivate");
+}
+
+/// A barrier is required as soon as one privatized symbol is both
+/// firstprivate and lastprivate.
+bool DataSharingProcessor::needBarrier() {
+  using Flag = Fortran::semantics::Symbol::Flag;
+  for (const Fortran::semantics::Symbol *candidate : privatizedSymbols) {
+    const bool isFirstPrivate = candidate->test(Flag::OmpFirstPrivate);
+    if (isFirstPrivate && candidate->test(Flag::OmpLastPrivate))
+      return true;
+  }
+  return false;
+}
+
+/// Emit an `omp.barrier` at the current location when needBarrier() says so.
+void DataSharingProcessor::insertBarrier() {
+  // Emit implicit barrier to synchronize threads and avoid data races on
+  // initialization of firstprivate variables and post-update of lastprivate
+  // variables.
+  // FIXME: Emit barrier for lastprivate clause when 'sections' directive has
+  // 'nowait' clause. Otherwise, emit barrier when 'sections' directive has
+  // both firstprivate and lastprivate clause.
+  // Emit implicit barrier for linear clause. Maybe on somewhere else.
+  if (!needBarrier())
+    return;
+  firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
+}
+
+/// Prepare `lastPrivIP`, the insertion point for lastprivate copy-back code
+/// emitted for `op`, once per LASTPRIVATE clause in `opClauseList`:
+///  - `omp.section`: only for the lexically last section. For structured
+///    code a `fir.if true` region is created, `lastPrivIP` points into it,
+///    and `insPt` is moved to just before that `fir.if`. If the enclosing
+///    SECTIONS has NOWAIT, a barrier is also emitted inside that region.
+///    For unstructured code `lastPrivIP` is just before the section's
+///    terminator.
+///  - `omp.wsloop`: a single `fir.if` is created before the body terminator.
+///    Its condition tests whether the next iteration value passes the upper
+///    bound. The region stores that value into `loopIV`.
+///  - any other operation: reported as a TODO.
+/// The builder's insertion point is restored before returning.
+void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
+  bool cmpCreated = false;
+  mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint();
+  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+    if (!std::get_if<Fortran::parser::OmpClause::Lastprivate>(&clause.u))
+      continue;
+
+    // TODO: Add lastprivate support for simd construct
+    if (mlir::isa<mlir::omp::SectionOp>(op)) {
+      // NOTE(review): unlike the wsloop path, this path has no `cmpCreated`
+      // guard. With several LASTPRIVATE clauses a `fir.if` is created per
+      // clause and only the last one receives `lastPrivIP`. Confirm whether
+      // that is intended.
+      if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) {
+        // For `omp.sections`, lastprivatized variables occur in
+        // lexically final `omp.section` operation. The following FIR
+        // shall be generated for the same:
+        //
+        // omp.sections lastprivate(...) {
+        //   omp.section {...}
+        //   omp.section {...}
+        //   omp.section {
+        //     fir.allocate for `private`/`firstprivate`
+        //     <More operations here>
+        //     fir.if %true {
+        //       ^%lpv_update_blk
+        //     }
+        //   }
+        // }
+        //
+        // To keep code consistency while handling privatization
+        // through this control flow, add a `fir.if` operation
+        // that always evaluates to true, in order to create
+        // a dedicated sub-region in `omp.section` where
+        // lastprivate FIR can reside. Later canonicalizations
+        // will optimize away this operation.
+        if (!eval.lowerAsUnstructured()) {
+          auto ifOp = firOpBuilder.create<fir::IfOp>(
+              op->getLoc(),
+              firOpBuilder.createIntegerConstant(
+                  op->getLoc(), firOpBuilder.getIntegerType(1), 0x1),
+              /*else*/ false);
+          firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+          const Fortran::parser::OpenMPConstruct *parentOmpConstruct =
+              eval.parentConstruct->getIf<Fortran::parser::OpenMPConstruct>();
+          assert(parentOmpConstruct &&
+                 "Expected a valid enclosing OpenMP construct");
+          const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct =
+              std::get_if<Fortran::parser::OpenMPSectionsConstruct>(
+                  &parentOmpConstruct->u);
+          assert(sectionsConstruct &&
+                 "Expected an enclosing omp.sections construct");
+          const Fortran::parser::OmpClauseList &sectionsEndClauseList =
+              std::get<Fortran::parser::OmpClauseList>(
+                  std::get<Fortran::parser::OmpEndSectionsDirective>(
+                      sectionsConstruct->t)
+                      .t);
+          for (const Fortran::parser::OmpClause &otherClause :
+               sectionsEndClauseList.v) {
+            if (std::get_if<Fortran::parser::OmpClause::Nowait>(
+                    &otherClause.u)) {
+              // Emit implicit barrier to synchronize threads and avoid data
+              // races on post-update of lastprivate variables when `nowait`
+              // clause is present.
+              firOpBuilder.create<mlir::omp::BarrierOp>(
+                  converter.getCurrentLocation());
+            }
+          }
+          firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+          lastPrivIP = firOpBuilder.saveInsertionPoint();
+          // processStep2 restores `insPt`, so it resumes before the fir.if.
+          firOpBuilder.setInsertionPoint(ifOp);
+          insPt = firOpBuilder.saveInsertionPoint();
+        } else {
+          // Lastprivate operation is inserted at the end
+          // of the lexically last section in the sections
+          // construct
+          mlir::OpBuilder::InsertPoint unstructuredSectionsIP =
+              firOpBuilder.saveInsertionPoint();
+          mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
+          firOpBuilder.setInsertionPoint(lastOper);
+          lastPrivIP = firOpBuilder.saveInsertionPoint();
+          firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP);
+        }
+      }
+    } else if (mlir::isa<mlir::omp::WsLoopOp>(op)) {
+      // Update the original variable just before exiting the worksharing
+      // loop. Conversion as follows:
+      //
+      //                       omp.wsloop {
+      // omp.wsloop {            ...
+      //    ...                  store
+      //    store       ===>     %v = arith.addi %iv, %step
+      //    omp.yield            %cmp = %step < 0 ? %v < %ub : %v > %ub
+      // }                       fir.if %cmp {
+      //                           fir.store %v to %loopIV
+      //                           ^%lpv_update_blk:
+      //                         }
+      //                         omp.yield
+      //                       }
+      //
+
+      // Only generate the compare once in presence of multiple LastPrivate
+      // clauses.
+      if (cmpCreated)
+        continue;
+      cmpCreated = true;
+
+      mlir::Location loc = op->getLoc();
+      mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
+      firOpBuilder.setInsertionPoint(lastOper);
+
+      auto wsLoopOp = mlir::dyn_cast<mlir::omp::WsLoopOp>(op);
+      mlir::Value iv = op->getRegion(0).front().getArguments()[0];
+      mlir::Value ub = wsLoopOp.getUpperBound()[0];
+      mlir::Value step = wsLoopOp.getStep()[0];
+
+      // v = iv + step
+      // cmp = step < 0 ? v < ub : v > ub
+      mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
+      mlir::Value zero =
+          firOpBuilder.createIntegerConstant(loc, step.getType(), 0);
+      mlir::Value negativeStep = firOpBuilder.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::slt, step, zero);
+      mlir::Value vLT = firOpBuilder.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::slt, v, ub);
+      mlir::Value vGT = firOpBuilder.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::sgt, v, ub);
+      mlir::Value cmpOp = firOpBuilder.create<mlir::arith::SelectOp>(
+          loc, negativeStep, vLT, vGT);
+
+      auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
+      firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      assert(loopIV && "loopIV was not set");
+      firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
+      lastPrivIP = firOpBuilder.saveInsertionPoint();
+    } else {
+      TODO(converter.getCurrentLocation(),
+           "lastprivate clause in constructs other than "
+           "simd/worksharing-loop");
+    }
+  }
+  firOpBuilder.restoreInsertionPoint(localInsPt);
+}
+
+/// Collect symbols carrying `flag` from the construct's evaluation into
+/// `defaultSymbols`. Classify nested evaluations into
+/// `symbolsInNestedRegions` (evaluations that themselves have nested
+/// evaluations) or `symbolsInParentRegions` (leaf evaluations).
+void DataSharingProcessor::collectSymbols(
+    Fortran::semantics::Symbol::Flag flag) {
+  converter.collectSymbolSet(eval, defaultSymbols, flag,
+                             /*collectSymbols=*/true,
+                             /*collectHostAssociatedSymbols=*/true);
+
+  for (Fortran::lower::pft::Evaluation &nested : eval.getNestedEvaluations()) {
+    // Region-like evaluations collect their own symbols only; leaf
+    // evaluations collect host-associated symbols only.
+    const bool isRegion = nested.hasNestedEvaluations();
+    converter.collectSymbolSet(
+        nested, isRegion ? symbolsInNestedRegions : symbolsInParentRegions,
+        flag, /*collectSymbols=*/isRegion,
+        /*collectHostAssociatedSymbols=*/!isRegion);
+  }
+}
+
+/// For each DEFAULT(PRIVATE) or DEFAULT(FIRSTPRIVATE) clause, collect the
+/// symbols flagged OmpPrivate or OmpFirstPrivate respectively. Other
+/// DEFAULT kinds are ignored here.
+void DataSharingProcessor::collectDefaultSymbols() {
+  using DefaultType = Fortran::parser::OmpDefaultClause::Type;
+  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+    const auto *defaultClause =
+        std::get_if<Fortran::parser::OmpClause::Default>(&clause.u);
+    if (!defaultClause)
+      continue;
+
+    switch (defaultClause->v.v) {
+    case DefaultType::Private:
+      collectSymbols(Fortran::semantics::Symbol::Flag::OmpPrivate);
+      break;
+    case DefaultType::Firstprivate:
+      collectSymbols(Fortran::semantics::Symbol::Flag::OmpFirstPrivate);
+      break;
+    default:
+      break;
+    }
+  }
+}
+
+/// Clone every explicitly privatized symbol and apply firstprivate copies.
+/// A common block is privatized member by member.
+void DataSharingProcessor::privatize() {
+  auto privatizeOne = [this](const Fortran::semantics::Symbol *target) {
+    cloneSymbol(target);
+    copyFirstPrivateSymbol(target);
+  };
+
+  for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
+    const auto *commonDet =
+        sym->detailsIf<Fortran::semantics::CommonBlockDetails>();
+    if (!commonDet) {
+      privatizeOne(sym);
+      continue;
+    }
+    for (const auto &member : commonDet->objects())
+      privatizeOne(&*member);
+  }
+}
+
+/// Set up `lastPrivIP` for `op`, then emit the lastprivate copy-back for
+/// every privatized symbol (common blocks member by member).
+void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
+  insertLastPrivateCompare(op);
+  for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
+    const auto *commonDet =
+        sym->detailsIf<Fortran::semantics::CommonBlockDetails>();
+    if (!commonDet) {
+      copyLastPrivateSymbol(sym, &lastPrivIP);
+      continue;
+    }
+    for (const auto &member : commonDet->objects())
+      copyLastPrivateSymbol(&*member, &lastPrivIP);
+  }
+}
+
+/// Privatize the symbols implied by a DEFAULT clause, skipping procedures,
+/// derived-type and namelist names, and symbols that were seen in nested or
+/// parent regions or that are already explicitly privatized.
+void DataSharingProcessor::defaultPrivatize() {
+  for (const Fortran::semantics::Symbol *sym : defaultSymbols) {
+    const bool skip =
+        Fortran::semantics::IsProcedure(*sym) ||
+        sym->GetUltimate().has<Fortran::semantics::DerivedTypeDetails>() ||
+        sym->GetUltimate().has<Fortran::semantics::NamelistDetails>() ||
+        symbolsInNestedRegions.contains(sym) ||
+        symbolsInParentRegions.contains(sym) ||
+        privatizedSymbols.contains(sym);
+    if (skip)
+      continue;
+
+    cloneSymbol(sym);
+    copyFirstPrivateSymbol(sym);
+  }
+}
+
+} // namespace omp
+} // namespace lower
+} // namespace Fortran
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
new file mode 100644
index 00000000000000..10c0a30c09c391
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -0,0 +1,89 @@
+//===-- Lower/OpenMP/DataSharingProcessor.h ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+#ifndef FORTRAN_LOWER_DATASHARINGPROCESSOR_H
+#define FORTRAN_LOWER_DATASHARINGPROCESSOR_H
+
+#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/OpenMP.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/symbol.h"
+
+namespace Fortran {
+namespace lower {
+namespace omp {
+
+/// Lowers the data-sharing clauses (PRIVATE, FIRSTPRIVATE, LASTPRIVATE and
+/// DEFAULT(PRIVATE|FIRSTPRIVATE)) attached to an OpenMP construct. The
+/// converter, clause list and evaluation are held by reference.
+class DataSharingProcessor {
+  // Set when a LASTPRIVATE clause is found in `opClauseList`.
+  bool hasLastPrivateOp = false;
+  // Insertion point for lastprivate copy-back code.
+  mlir::OpBuilder::InsertPoint lastPrivIP;
+  // Insertion point saved and restored around the step-2 copy-back.
+  mlir::OpBuilder::InsertPoint insPt;
+  // Loop iteration variable, provided through setLoopIV().
+  mlir::Value loopIV;
+  // Symbols in private, firstprivate, and/or lastprivate clauses.
+  llvm::SetVector<const Fortran::semantics::Symbol *> privatizedSymbols;
+  // Candidates for DEFAULT(PRIVATE|FIRSTPRIVATE) privatization.
+  llvm::SetVector<const Fortran::semantics::Symbol *> defaultSymbols;
+  llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInNestedRegions;
+  llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInParentRegions;
+  Fortran::lower::AbstractConverter &converter;
+  fir::FirOpBuilder &firOpBuilder;
+  const Fortran::parser::OmpClauseList &opClauseList;
+  Fortran::lower::pft::Evaluation &eval;
+
+  // Symbol collection.
+  void collectSymbols(Fortran::semantics::Symbol::Flag flag);
+  void collectOmpObjectListSymbol(
+      const Fortran::parser::OmpObjectList &ompObjectList,
+      llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet);
+  void collectSymbolsForPrivatization();
+  void collectDefaultSymbols();
+
+  // Step 1 helpers.
+  bool needBarrier();
+  void insertBarrier();
+  void privatize();
+  void defaultPrivatize();
+  void cloneSymbol(const Fortran::semantics::Symbol *sym);
+  void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym);
+
+  // Step 2 helpers.
+  void copyLastPrivatize(mlir::Operation *op);
+  void insertLastPrivateCompare(mlir::Operation *op);
+  void copyLastPrivateSymbol(const Fortran::semantics::Symbol *sym,
+                             mlir::OpBuilder::InsertPoint *lastPrivIP);
+  void insertDeallocs();
+
+public:
+  DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
+                       const Fortran::parser::OmpClauseList &opClauseList,
+                       Fortran::lower::pft::Evaluation &eval)
+      : converter(converter), firOpBuilder(converter.getFirOpBuilder()),
+        opClauseList(opClauseList), eval(eval) {}
+
+  // Privatisation is split into two steps.
+  // Step1 clones every privatised symbol and performs the copies for
+  // firstprivates. It runs wherever processStep1 is called: usually inside
+  // the Operation corresponding to the OpenMP construct, but for looping
+  // constructs just before that Operation. The split exists so that looping
+  // constructs can be privatised before the operation is created, since the
+  // bounds of the MLIR OpenMP operation can be privatised.
+  // Step2 performs the copying for lastprivates, which needs the MLIR
+  // operation to insert the lastprivate update. Step2 also emits the
+  // deallocation code.
+  void processStep1();
+  void processStep2(mlir::Operation *op, bool isLoop);
+
+  // Must be called at most once; asserts if the variable was already set.
+  void setLoopIV(mlir::Value iv) {
+    assert(!loopIV && "Loop iteration variable already set");
+    loopIV = iv;
+  }
+};
+
+} // namespace omp
+} // namespace lower
+} // namespace Fortran
+
+#endif // FORTRAN_LOWER_DATASHARINGPROCESSOR_H
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
similarity index 55%
rename from flang/lib/Lower/OpenMP.cpp
rename to flang/lib/Lower/OpenMP/OpenMP.cpp
index 9397af8b8bd05e..3aefad6cf0ec1f 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -11,109 +11,36 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/OpenMP.h"
+
+#include "ClauseProcessor.h"
+#include "DataSharingProcessor.h"
#include "DirectivesCommon.h"
+#include "ReductionProcessor.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
-#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
-#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
-#include "llvm/Support/CommandLine.h"
-
-static llvm::cl::opt<bool> treatIndexAsSection(
- "openmp-treat-index-as-section",
- llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
- llvm::cl::init(true));
-using DeclareTargetCapturePair =
- std::pair<mlir::omp::DeclareTargetCaptureClause,
- Fortran::semantics::Symbol>;
+using namespace Fortran::lower::omp;
//===----------------------------------------------------------------------===//
-// Common helper functions
+// Code generation helper functions
//===----------------------------------------------------------------------===//
-static Fortran::semantics::Symbol *
-getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
- Fortran::semantics::Symbol *sym = nullptr;
- std::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::Designator &designator) {
- if (auto *arrayEle =
- Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
- designator)) {
- sym = GetFirstName(arrayEle->base).symbol;
- } else if (auto *structComp = Fortran::parser::Unwrap<
- Fortran::parser::StructureComponent>(designator)) {
- sym = structComp->component.symbol;
- } else if (const Fortran::parser::Name *name =
- Fortran::semantics::getDesignatorNameIfDataRef(
- designator)) {
- sym = name->symbol;
- }
- },
- [&](const Fortran::parser::Name &name) { sym = name.symbol; }},
- ompObject.u);
- return sym;
-}
-
-static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- llvm::SmallVectorImpl<mlir::Value> &operands) {
- auto addOperands = [&](Fortran::lower::SymbolRef sym) {
- const mlir::Value variable = converter.getSymbolAddress(sym);
- if (variable) {
- operands.push_back(variable);
- } else {
- if (const auto *details =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
- operands.push_back(converter.getSymbolAddress(details->symbol()));
- converter.copySymbolBinding(details->symbol(), sym);
- }
- }
- };
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
- addOperands(*sym);
- }
-}
-
-static void gatherFuncAndVarSyms(
- const Fortran::parser::OmpObjectList &objList,
- mlir::omp::DeclareTargetCaptureClause clause,
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
- for (const Fortran::parser::OmpObject &ompObject : objList.v) {
- Fortran::common::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::Designator &designator) {
- if (const Fortran::parser::Name *name =
- Fortran::semantics::getDesignatorNameIfDataRef(
- designator)) {
- symbolAndClause.emplace_back(clause, *name->symbol);
- }
- },
- [&](const Fortran::parser::Name &name) {
- symbolAndClause.emplace_back(clause, *name.symbol);
- }},
- ompObject.u);
- }
-}
-
static Fortran::lower::pft::Evaluation *
getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) {
// Return the Evaluation of the innermost collapsed loop, or the current one
@@ -142,1961 +69,6 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
converter.genEval(e);
}
-//===----------------------------------------------------------------------===//
-// DataSharingProcessor
-//===----------------------------------------------------------------------===//
-
-class DataSharingProcessor {
- bool hasLastPrivateOp;
- mlir::OpBuilder::InsertPoint lastPrivIP;
- mlir::OpBuilder::InsertPoint insPt;
- mlir::Value loopIV;
- // Symbols in private, firstprivate, and/or lastprivate clauses.
- llvm::SetVector<const Fortran::semantics::Symbol *> privatizedSymbols;
- llvm::SetVector<const Fortran::semantics::Symbol *> defaultSymbols;
- llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInNestedRegions;
- llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInParentRegions;
- Fortran::lower::AbstractConverter &converter;
- fir::FirOpBuilder &firOpBuilder;
- const Fortran::parser::OmpClauseList &opClauseList;
- Fortran::lower::pft::Evaluation &eval;
-
- bool needBarrier();
- void collectSymbols(Fortran::semantics::Symbol::Flag flag);
- void collectOmpObjectListSymbol(
- const Fortran::parser::OmpObjectList &ompObjectList,
- llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet);
- void collectSymbolsForPrivatization();
- void insertBarrier();
- void collectDefaultSymbols();
- void privatize();
- void defaultPrivatize();
- void copyLastPrivatize(mlir::Operation *op);
- void insertLastPrivateCompare(mlir::Operation *op);
- void cloneSymbol(const Fortran::semantics::Symbol *sym);
- void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym);
- void copyLastPrivateSymbol(const Fortran::semantics::Symbol *sym,
- mlir::OpBuilder::InsertPoint *lastPrivIP);
- void insertDeallocs();
-
-public:
- DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpClauseList &opClauseList,
- Fortran::lower::pft::Evaluation &eval)
- : hasLastPrivateOp(false), converter(converter),
- firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
- eval(eval) {}
- // Privatisation is split into two steps.
- // Step1 performs cloning of all privatisation clauses and copying for
- // firstprivates. Step1 is performed at the place where process/processStep1
- // is called. This is usually inside the Operation corresponding to the OpenMP
- // construct, for looping constructs this is just before the Operation. The
- // split into two steps was performed basically to be able to call
- // privatisation for looping constructs before the operation is created since
- // the bounds of the MLIR OpenMP operation can be privatised.
- // Step2 performs the copying for lastprivates and requires knowledge of the
- // MLIR operation to insert the last private update. Step2 adds
- // dealocation code as well.
- void processStep1();
- void processStep2(mlir::Operation *op, bool isLoop);
-
- void setLoopIV(mlir::Value iv) {
- assert(!loopIV && "Loop iteration variable already set");
- loopIV = iv;
- }
-};
-
-void DataSharingProcessor::processStep1() {
- collectSymbolsForPrivatization();
- collectDefaultSymbols();
- privatize();
- defaultPrivatize();
- insertBarrier();
-}
-
-void DataSharingProcessor::processStep2(mlir::Operation *op, bool isLoop) {
- insPt = firOpBuilder.saveInsertionPoint();
- copyLastPrivatize(op);
- firOpBuilder.restoreInsertionPoint(insPt);
-
- if (isLoop) {
- // push deallocs out of the loop
- firOpBuilder.setInsertionPointAfter(op);
- insertDeallocs();
- } else {
- // insert dummy instruction to mark the insertion position
- mlir::Value undefMarker = firOpBuilder.create<fir::UndefOp>(
- op->getLoc(), firOpBuilder.getIndexType());
- insertDeallocs();
- firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
- }
-}
-
-void DataSharingProcessor::insertDeallocs() {
- for (const Fortran::semantics::Symbol *sym : privatizedSymbols)
- if (Fortran::semantics::IsAllocatable(sym->GetUltimate())) {
- converter.createHostAssociateVarCloneDealloc(*sym);
- }
-}
-
-void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym) {
- // Privatization for symbols which are pre-determined (like loop index
- // variables) happen separately, for everything else privatize here.
- if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined))
- return;
- bool success = converter.createHostAssociateVarClone(*sym);
- (void)success;
- assert(success && "Privatization failed due to existing binding");
-}
-
-void DataSharingProcessor::copyFirstPrivateSymbol(
- const Fortran::semantics::Symbol *sym) {
- if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate))
- converter.copyHostAssociateVar(*sym);
-}
-
-void DataSharingProcessor::copyLastPrivateSymbol(
- const Fortran::semantics::Symbol *sym,
- [[maybe_unused]] mlir::OpBuilder::InsertPoint *lastPrivIP) {
- if (sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate))
- converter.copyHostAssociateVar(*sym, lastPrivIP);
-}
-
-void DataSharingProcessor::collectOmpObjectListSymbol(
- const Fortran::parser::OmpObjectList &ompObjectList,
- llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet) {
- for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
- symbolSet.insert(sym);
- }
-}
-
-void DataSharingProcessor::collectSymbolsForPrivatization() {
- bool hasCollapse = false;
- for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
- if (const auto &privateClause =
- std::get_if<Fortran::parser::OmpClause::Private>(&clause.u)) {
- collectOmpObjectListSymbol(privateClause->v, privatizedSymbols);
- } else if (const auto &firstPrivateClause =
- std::get_if<Fortran::parser::OmpClause::Firstprivate>(
- &clause.u)) {
- collectOmpObjectListSymbol(firstPrivateClause->v, privatizedSymbols);
- } else if (const auto &lastPrivateClause =
- std::get_if<Fortran::parser::OmpClause::Lastprivate>(
- &clause.u)) {
- collectOmpObjectListSymbol(lastPrivateClause->v, privatizedSymbols);
- hasLastPrivateOp = true;
- } else if (std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
- hasCollapse = true;
- }
- }
-
- if (hasCollapse && hasLastPrivateOp)
- TODO(converter.getCurrentLocation(), "Collapse clause with lastprivate");
-}
-
-bool DataSharingProcessor::needBarrier() {
- for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
- if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate) &&
- sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate))
- return true;
- }
- return false;
-}
-
-void DataSharingProcessor::insertBarrier() {
- // Emit implicit barrier to synchronize threads and avoid data races on
- // initialization of firstprivate variables and post-update of lastprivate
- // variables.
- // FIXME: Emit barrier for lastprivate clause when 'sections' directive has
- // 'nowait' clause. Otherwise, emit barrier when 'sections' directive has
- // both firstprivate and lastprivate clause.
- // Emit implicit barrier for linear clause. Maybe on somewhere else.
- if (needBarrier())
- firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
-}
-
-void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
- bool cmpCreated = false;
- mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint();
- for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
- if (std::get_if<Fortran::parser::OmpClause::Lastprivate>(&clause.u)) {
- // TODO: Add lastprivate support for simd construct
- if (mlir::isa<mlir::omp::SectionOp>(op)) {
- if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) {
- // For `omp.sections`, lastprivatized variables occur in
- // lexically final `omp.section` operation. The following FIR
- // shall be generated for the same:
- //
- // omp.sections lastprivate(...) {
- // omp.section {...}
- // omp.section {...}
- // omp.section {
- // fir.allocate for `private`/`firstprivate`
- // <More operations here>
- // fir.if %true {
- // ^%lpv_update_blk
- // }
- // }
- // }
- //
- // To keep code consistency while handling privatization
- // through this control flow, add a `fir.if` operation
- // that always evaluates to true, in order to create
- // a dedicated sub-region in `omp.section` where
- // lastprivate FIR can reside. Later canonicalizations
- // will optimize away this operation.
- if (!eval.lowerAsUnstructured()) {
- auto ifOp = firOpBuilder.create<fir::IfOp>(
- op->getLoc(),
- firOpBuilder.createIntegerConstant(
- op->getLoc(), firOpBuilder.getIntegerType(1), 0x1),
- /*else*/ false);
- firOpBuilder.setInsertionPointToStart(
- &ifOp.getThenRegion().front());
-
- const Fortran::parser::OpenMPConstruct *parentOmpConstruct =
- eval.parentConstruct->getIf<Fortran::parser::OpenMPConstruct>();
- assert(parentOmpConstruct &&
- "Expected a valid enclosing OpenMP construct");
- const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct =
- std::get_if<Fortran::parser::OpenMPSectionsConstruct>(
- &parentOmpConstruct->u);
- assert(sectionsConstruct &&
- "Expected an enclosing omp.sections construct");
- const Fortran::parser::OmpClauseList &sectionsEndClauseList =
- std::get<Fortran::parser::OmpClauseList>(
- std::get<Fortran::parser::OmpEndSectionsDirective>(
- sectionsConstruct->t)
- .t);
- for (const Fortran::parser::OmpClause &otherClause :
- sectionsEndClauseList.v)
- if (std::get_if<Fortran::parser::OmpClause::Nowait>(
- &otherClause.u))
- // Emit implicit barrier to synchronize threads and avoid data
- // races on post-update of lastprivate variables when `nowait`
- // clause is present.
- firOpBuilder.create<mlir::omp::BarrierOp>(
- converter.getCurrentLocation());
- firOpBuilder.setInsertionPointToStart(
- &ifOp.getThenRegion().front());
- lastPrivIP = firOpBuilder.saveInsertionPoint();
- firOpBuilder.setInsertionPoint(ifOp);
- insPt = firOpBuilder.saveInsertionPoint();
- } else {
- // Lastprivate operation is inserted at the end
- // of the lexically last section in the sections
- // construct
- mlir::OpBuilder::InsertPoint unstructuredSectionsIP =
- firOpBuilder.saveInsertionPoint();
- mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
- firOpBuilder.setInsertionPoint(lastOper);
- lastPrivIP = firOpBuilder.saveInsertionPoint();
- firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP);
- }
- }
- } else if (mlir::isa<mlir::omp::WsLoopOp>(op)) {
- // Update the original variable just before exiting the worksharing
- // loop. Conversion as follows:
- //
- // omp.wsloop {
- // omp.wsloop { ...
- // ... store
- // store ===> %v = arith.addi %iv, %step
- // omp.yield %cmp = %step < 0 ? %v < %ub : %v > %ub
- // } fir.if %cmp {
- // fir.store %v to %loopIV
- // ^%lpv_update_blk:
- // }
- // omp.yield
- // }
- //
-
- // Only generate the compare once in presence of multiple LastPrivate
- // clauses.
- if (cmpCreated)
- continue;
- cmpCreated = true;
-
- mlir::Location loc = op->getLoc();
- mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
- firOpBuilder.setInsertionPoint(lastOper);
-
- mlir::Value iv = op->getRegion(0).front().getArguments()[0];
- mlir::Value ub =
- mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getUpperBound()[0];
- mlir::Value step = mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getStep()[0];
-
- // v = iv + step
- // cmp = step < 0 ? v < ub : v > ub
- mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
- mlir::Value zero =
- firOpBuilder.createIntegerConstant(loc, step.getType(), 0);
- mlir::Value negativeStep = firOpBuilder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::slt, step, zero);
- mlir::Value vLT = firOpBuilder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::slt, v, ub);
- mlir::Value vGT = firOpBuilder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::sgt, v, ub);
- mlir::Value cmpOp = firOpBuilder.create<mlir::arith::SelectOp>(
- loc, negativeStep, vLT, vGT);
-
- auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
- firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- assert(loopIV && "loopIV was not set");
- firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
- lastPrivIP = firOpBuilder.saveInsertionPoint();
- } else {
- TODO(converter.getCurrentLocation(),
- "lastprivate clause in constructs other than "
- "simd/worksharing-loop");
- }
- }
- }
- firOpBuilder.restoreInsertionPoint(localInsPt);
-}
-
-void DataSharingProcessor::collectSymbols(
- Fortran::semantics::Symbol::Flag flag) {
- converter.collectSymbolSet(eval, defaultSymbols, flag,
- /*collectSymbols=*/true,
- /*collectHostAssociatedSymbols=*/true);
- for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
- if (e.hasNestedEvaluations())
- converter.collectSymbolSet(e, symbolsInNestedRegions, flag,
- /*collectSymbols=*/true,
- /*collectHostAssociatedSymbols=*/false);
- else
- converter.collectSymbolSet(e, symbolsInParentRegions, flag,
- /*collectSymbols=*/false,
- /*collectHostAssociatedSymbols=*/true);
- }
-}
-
-void DataSharingProcessor::collectDefaultSymbols() {
- for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
- if (const auto &defaultClause =
- std::get_if<Fortran::parser::OmpClause::Default>(&clause.u)) {
- if (defaultClause->v.v ==
- Fortran::parser::OmpDefaultClause::Type::Private)
- collectSymbols(Fortran::semantics::Symbol::Flag::OmpPrivate);
- else if (defaultClause->v.v ==
- Fortran::parser::OmpDefaultClause::Type::Firstprivate)
- collectSymbols(Fortran::semantics::Symbol::Flag::OmpFirstPrivate);
- }
- }
-}
-
-void DataSharingProcessor::privatize() {
- for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
- if (const auto *commonDet =
- sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
- for (const auto &mem : commonDet->objects()) {
- cloneSymbol(&*mem);
- copyFirstPrivateSymbol(&*mem);
- }
- } else {
- cloneSymbol(sym);
- copyFirstPrivateSymbol(sym);
- }
- }
-}
-
-void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
- insertLastPrivateCompare(op);
- for (const Fortran::semantics::Symbol *sym : privatizedSymbols)
- if (const auto *commonDet =
- sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
- for (const auto &mem : commonDet->objects()) {
- copyLastPrivateSymbol(&*mem, &lastPrivIP);
- }
- } else {
- copyLastPrivateSymbol(sym, &lastPrivIP);
- }
-}
-
-void DataSharingProcessor::defaultPrivatize() {
- for (const Fortran::semantics::Symbol *sym : defaultSymbols) {
- if (!Fortran::semantics::IsProcedure(*sym) &&
- !sym->GetUltimate().has<Fortran::semantics::DerivedTypeDetails>() &&
- !sym->GetUltimate().has<Fortran::semantics::NamelistDetails>() &&
- !symbolsInNestedRegions.contains(sym) &&
- !symbolsInParentRegions.contains(sym) &&
- !privatizedSymbols.contains(sym)) {
- cloneSymbol(sym);
- copyFirstPrivateSymbol(sym);
- }
- }
-}
-
-//===----------------------------------------------------------------------===//
-// ClauseProcessor
-//===----------------------------------------------------------------------===//
-
-/// Class that handles the processing of OpenMP clauses.
-///
-/// Its `process<ClauseName>()` methods perform MLIR code generation for their
-/// corresponding clause if it is present in the clause list. Otherwise, they
-/// will return `false` to signal that the clause was not found.
-///
-/// The intended use is of this class is to move clause processing outside of
-/// construct processing, since the same clauses can appear attached to
-/// different constructs and constructs can be combined, so that code
-/// duplication is minimized.
-///
-/// Each construct-lowering function only calls the `process<ClauseName>()`
-/// methods that relate to clauses that can impact the lowering of that
-/// construct.
-class ClauseProcessor {
- using ClauseTy = Fortran::parser::OmpClause;
-
-public:
- ClauseProcessor(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses)
- : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
-
- // 'Unique' clauses: They can appear at most once in the clause list.
- bool
- processCollapse(mlir::Location currentLocation,
- Fortran::lower::pft::Evaluation &eval,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
- std::size_t &loopVarTypeSize) const;
- bool processDefault() const;
- bool processDevice(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
- bool processFinal(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processHint(mlir::IntegerAttr &result) const;
- bool processMergeable(mlir::UnitAttr &result) const;
- bool processNowait(mlir::UnitAttr &result) const;
- bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processOrdered(mlir::IntegerAttr &result) const;
- bool processPriority(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const;
- bool processSafelen(mlir::IntegerAttr &result) const;
- bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr,
- mlir::omp::ScheduleModifierAttr &modifierAttr,
- mlir::UnitAttr &simdModifierAttr) const;
- bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processSimdlen(mlir::IntegerAttr &result) const;
- bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processUntied(mlir::UnitAttr &result) const;
-
- // 'Repeatable' clauses: They can appear multiple times in the clause list.
- bool
- processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
- llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const;
- bool processCopyin() const;
- bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
- llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
- bool
- processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
- bool
- processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
- mlir::Value &result) const;
- bool
- processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
-
- // This method is used to process a map clause.
- // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to
- // store the original type, location and Fortran symbol for the map operands.
- // They may be used later on to create the block_arguments for some of the
- // target directives that require it.
- bool processMap(mlir::Location currentLocation,
- const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
- llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *mapSymbols = nullptr) const;
- bool
- processReduction(mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *reductionSymbols = nullptr) const;
- bool processSectionsReduction(mlir::Location currentLocation) const;
- bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
- bool
- processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) const;
- bool
- processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) const;
-
- template <typename T>
- bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands);
-
- // Call this method for these clauses that should be supported but are not
- // implemented yet. It triggers a compilation error if any of the given
- // clauses is found.
- template <typename... Ts>
- void processTODO(mlir::Location currentLocation,
- llvm::omp::Directive directive) const;
-
-private:
- using ClauseIterator = std::list<ClauseTy>::const_iterator;
-
- /// Utility to find a clause within a range in the clause list.
- template <typename T>
- static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end) {
- for (ClauseIterator it = begin; it != end; ++it) {
- if (std::get_if<T>(&it->u))
- return it;
- }
-
- return end;
- }
-
- /// Return the first instance of the given clause found in the clause list or
- /// `nullptr` if not present. If more than one instance is expected, use
- /// `findRepeatableClause` instead.
- template <typename T>
- const T *
- findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const {
- ClauseIterator it = findClause<T>(clauses.v.begin(), clauses.v.end());
- if (it != clauses.v.end()) {
- if (source)
- *source = &it->source;
- return &std::get<T>(it->u);
- }
- return nullptr;
- }
-
- /// Call `callbackFn` for each occurrence of the given clause. Return `true`
- /// if at least one instance was found.
- template <typename T>
- bool findRepeatableClause(
- std::function<void(const T *, const Fortran::parser::CharBlock &source)>
- callbackFn) const {
- bool found = false;
- ClauseIterator nextIt, endIt = clauses.v.end();
- for (ClauseIterator it = clauses.v.begin(); it != endIt; it = nextIt) {
- nextIt = findClause<T>(it, endIt);
-
- if (nextIt != endIt) {
- callbackFn(&std::get<T>(nextIt->u), nextIt->source);
- found = true;
- ++nextIt;
- }
- }
- return found;
- }
-
- /// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
- template <typename T>
- bool markClauseOccurrence(mlir::UnitAttr &result) const {
- if (findUniqueClause<T>()) {
- result = converter.getFirOpBuilder().getUnitAttr();
- return true;
- }
- return false;
- }
-
- Fortran::lower::AbstractConverter &converter;
- Fortran::semantics::SemanticsContext &semaCtx;
- const Fortran::parser::OmpClauseList &clauses;
-};
-
-//===----------------------------------------------------------------------===//
-// ClauseProcessor helper functions
-//===----------------------------------------------------------------------===//
-
-/// Check for unsupported map operand types.
-static void checkMapType(mlir::Location location, mlir::Type type) {
- if (auto refType = type.dyn_cast<fir::ReferenceType>())
- type = refType.getElementType();
- if (auto boxType = type.dyn_cast_or_null<fir::BoxType>())
- if (!boxType.getElementType().isa<fir::PointerType>())
- TODO(location, "OMPD_target_data MapOperand BoxType");
-}
-
-class ReductionProcessor {
-public:
- // TODO: Move this enumeration to the OpenMP dialect
- enum ReductionIdentifier {
- ID,
- USER_DEF_OP,
- ADD,
- SUBTRACT,
- MULTIPLY,
- AND,
- OR,
- EQV,
- NEQV,
- MAX,
- MIN,
- IAND,
- IOR,
- IEOR
- };
- static ReductionIdentifier
- getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
- auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
- getRealName(pd).ToString())
- .Case("max", ReductionIdentifier::MAX)
- .Case("min", ReductionIdentifier::MIN)
- .Case("iand", ReductionIdentifier::IAND)
- .Case("ior", ReductionIdentifier::IOR)
- .Case("ieor", ReductionIdentifier::IEOR)
- .Default(std::nullopt);
- assert(redType && "Invalid Reduction");
- return *redType;
- }
-
- static ReductionIdentifier getReductionType(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- return ReductionIdentifier::ADD;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
- return ReductionIdentifier::SUBTRACT;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- return ReductionIdentifier::MULTIPLY;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- return ReductionIdentifier::AND;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- return ReductionIdentifier::EQV;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- return ReductionIdentifier::OR;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
- return ReductionIdentifier::NEQV;
- default:
- llvm_unreachable("unexpected intrinsic operator in reduction");
- }
- }
-
- static bool supportedIntrinsicProcReduction(
- const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- if (!name->symbol->GetUltimate().attrs().test(
- Fortran::semantics::Attr::INTRINSIC))
- return false;
- auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString())
- .Case("max", true)
- .Case("min", true)
- .Case("iand", true)
- .Case("ior", true)
- .Case("ieor", true)
- .Default(false);
- return redType;
- }
-
- static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::Name *name) {
- return name->symbol->GetUltimate().name();
- }
-
- static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- return getRealName(name);
- }
-
- static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
- return (llvm::Twine(name) +
- (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
- llvm::Twine(ty.getIntOrFloatBitWidth()))
- .str();
- }
-
- static std::string getReductionName(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty) {
- std::string reductionName;
-
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- reductionName = "add_reduction";
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- reductionName = "multiply_reduction";
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- return "and_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- return "eqv_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- return "or_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
- return "neqv_reduction";
- default:
- reductionName = "other_reduction";
- break;
- }
-
- return getReductionName(reductionName, ty);
- }
-
- /// This function returns the identity value of the operator \p
- /// reductionOpName. For example:
- /// 0 + x = x,
- /// 1 * x = x
- static int getOperationIdentity(ReductionIdentifier redId,
- mlir::Location loc) {
- switch (redId) {
- case ReductionIdentifier::ADD:
- case ReductionIdentifier::OR:
- case ReductionIdentifier::NEQV:
- return 0;
- case ReductionIdentifier::MULTIPLY:
- case ReductionIdentifier::AND:
- case ReductionIdentifier::EQV:
- return 1;
- default:
- TODO(loc, "Reduction of some intrinsic operators is not supported");
- }
- }
-
- static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
- ReductionIdentifier redId,
- fir::FirOpBuilder &builder) {
- assert((fir::isa_integer(type) || fir::isa_real(type) ||
- type.isa<fir::LogicalType>()) &&
- "only integer, logical and real types are currently supported");
- switch (redId) {
- case ReductionIdentifier::MAX: {
- if (auto ty = type.dyn_cast<mlir::FloatType>()) {
- const llvm::fltSemantics &sem = ty.getFloatSemantics();
- return builder.createRealConstant(
- loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
- }
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, minInt);
- }
- case ReductionIdentifier::MIN: {
- if (auto ty = type.dyn_cast<mlir::FloatType>()) {
- const llvm::fltSemantics &sem = ty.getFloatSemantics();
- return builder.createRealConstant(
- loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
- }
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, maxInt);
- }
- case ReductionIdentifier::IOR: {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, zeroInt);
- }
- case ReductionIdentifier::IEOR: {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, zeroInt);
- }
- case ReductionIdentifier::IAND: {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, allOnInt);
- }
- case ReductionIdentifier::ADD:
- case ReductionIdentifier::MULTIPLY:
- case ReductionIdentifier::AND:
- case ReductionIdentifier::OR:
- case ReductionIdentifier::EQV:
- case ReductionIdentifier::NEQV:
- if (type.isa<mlir::FloatType>())
- return builder.create<mlir::arith::ConstantOp>(
- loc, type,
- builder.getFloatAttr(type,
- (double)getOperationIdentity(redId, loc)));
-
- if (type.isa<fir::LogicalType>()) {
- mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
- loc, builder.getI1Type(),
- builder.getIntegerAttr(builder.getI1Type(),
- getOperationIdentity(redId, loc)));
- return builder.createConvert(loc, type, intConst);
- }
-
- return builder.create<mlir::arith::ConstantOp>(
- loc, type,
- builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
- case ReductionIdentifier::ID:
- case ReductionIdentifier::USER_DEF_OP:
- case ReductionIdentifier::SUBTRACT:
- TODO(loc, "Reduction of some identifier types is not supported");
- }
- llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
- }
-
- template <typename FloatOp, typename IntegerOp>
- static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
- mlir::Type type, mlir::Location loc,
- mlir::Value op1, mlir::Value op2) {
- assert(type.isIntOrIndexOrFloat() &&
- "only integer and float types are currently supported");
- if (type.isIntOrIndex())
- return builder.create<IntegerOp>(loc, op1, op2);
- return builder.create<FloatOp>(loc, op1, op2);
- }
-
- static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
- mlir::Location loc,
- ReductionIdentifier redId,
- mlir::Type type, mlir::Value op1,
- mlir::Value op2) {
- mlir::Value reductionOp;
- switch (redId) {
- case ReductionIdentifier::MAX:
- reductionOp =
- getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
- builder, type, loc, op1, op2);
- break;
- case ReductionIdentifier::MIN:
- reductionOp =
- getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
- builder, type, loc, op1, op2);
- break;
- case ReductionIdentifier::IOR:
- assert((type.isIntOrIndex()) && "only integer is expected");
- reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
- break;
- case ReductionIdentifier::IEOR:
- assert((type.isIntOrIndex()) && "only integer is expected");
- reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
- break;
- case ReductionIdentifier::IAND:
- assert((type.isIntOrIndex()) && "only integer is expected");
- reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
- break;
- case ReductionIdentifier::ADD:
- reductionOp =
- getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
- builder, type, loc, op1, op2);
- break;
- case ReductionIdentifier::MULTIPLY:
- reductionOp =
- getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
- builder, type, loc, op1, op2);
- break;
- case ReductionIdentifier::AND: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
-
- mlir::Value andiOp =
- builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
-
- reductionOp = builder.createConvert(loc, type, andiOp);
- break;
- }
- case ReductionIdentifier::OR: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
-
- mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
-
- reductionOp = builder.createConvert(loc, type, oriOp);
- break;
- }
- case ReductionIdentifier::EQV: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
-
- mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
-
- reductionOp = builder.createConvert(loc, type, cmpiOp);
- break;
- }
- case ReductionIdentifier::NEQV: {
- mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
- mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
-
- mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
-
- reductionOp = builder.createConvert(loc, type, cmpiOp);
- break;
- }
- default:
- TODO(loc, "Reduction of some intrinsic operators is not supported");
- }
-
- return reductionOp;
- }
-
- /// Creates an OpenMP reduction declaration and inserts it into the provided
- /// symbol table. The declaration has a constant initializer with the neutral
- /// value `initValue`, and the reduction combiner carried over from `reduce`.
- /// TODO: Generalize this for non-integer types, add atomic region.
- static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
-
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
-
- mlir::OpBuilder modBuilder(module.getBodyRegion());
-
- decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
- loc, reductionOpName, type);
- builder.createBlock(&decl.getInitializerRegion(),
- decl.getInitializerRegion().end(), {type}, {loc});
- builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
- mlir::Value init = getReductionInitValue(loc, type, redId, builder);
- builder.create<mlir::omp::YieldOp>(loc, init);
-
- builder.createBlock(&decl.getReductionRegion(),
- decl.getReductionRegion().end(), {type, type},
- {loc, loc});
-
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
- mlir::Value reductionOp =
- createScalarCombiner(builder, loc, redId, type, op1, op2);
- builder.create<mlir::omp::YieldOp>(loc, reductionOp);
-
- return decl;
- }
-
- /// Creates a reduction declaration and associates it with an OpenMP block
- /// directive.
- static void
- addReductionDecl(mlir::Location currentLocation,
- Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpReductionClause &reduction,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *reductionSymbols = nullptr) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::omp::ReductionDeclareOp decl;
- const auto &redOperator{
- std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
- const auto &objectList{
- std::get<Fortran::parser::OmpObjectList>(reduction.t)};
- if (const auto &redDefinedOp =
- std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
- const auto &intrinsicOp{
- std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
- redDefinedOp->u)};
- ReductionIdentifier redId = getReductionType(intrinsicOp);
- switch (redId) {
- case ReductionIdentifier::ADD:
- case ReductionIdentifier::MULTIPLY:
- case ReductionIdentifier::AND:
- case ReductionIdentifier::EQV:
- case ReductionIdentifier::OR:
- case ReductionIdentifier::NEQV:
- break;
- default:
- TODO(currentLocation,
- "Reduction of some intrinsic operators is not supported");
- break;
- }
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- if (reductionSymbols)
- reductionSymbols->push_back(symbol);
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- if (redType.isa<fir::LogicalType>())
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
- redId, redType, currentLocation);
- else if (redType.isIntOrIndexOrFloat()) {
- decl = createReductionDecl(firOpBuilder,
- getReductionName(intrinsicOp, redType),
- redId, redType, currentLocation);
- } else {
- TODO(currentLocation, "Reduction of some types is not supported");
- }
- reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- } else if (const auto *reductionIntrinsic =
- std::get_if<Fortran::parser::ProcedureDesignator>(
- &redOperator.u)) {
- if (ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic)) {
- ReductionProcessor::ReductionIdentifier redId =
- ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- if (reductionSymbols)
- reductionSymbols->push_back(symbol);
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- assert(redType.isIntOrIndexOrFloat() &&
- "Unsupported reduction type");
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(getRealName(*reductionIntrinsic).ToString(),
- redType),
- redId, redType, currentLocation);
- reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- }
- }
- }
-};
-
-static mlir::omp::ScheduleModifier
-translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) {
- switch (m.v) {
- case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic:
- return mlir::omp::ScheduleModifier::monotonic;
- case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic:
- return mlir::omp::ScheduleModifier::nonmonotonic;
- case Fortran::parser::OmpScheduleModifierType::ModType::Simd:
- return mlir::omp::ScheduleModifier::simd;
- }
- return mlir::omp::ScheduleModifier::none;
-}
-
-static mlir::omp::ScheduleModifier
-getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) {
- const auto &modifier =
- std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
- // The input may have the modifier any order, so we look for one that isn't
- // SIMD. If modifier is not set at all, fall down to the bottom and return
- // "none".
- if (modifier) {
- const auto &modType1 =
- std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
- if (modType1.v.v ==
- Fortran::parser::OmpScheduleModifierType::ModType::Simd) {
- const auto &modType2 = std::get<
- std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
- modifier->t);
- if (modType2 &&
- modType2->v.v !=
- Fortran::parser::OmpScheduleModifierType::ModType::Simd)
- return translateScheduleModifier(modType2->v);
-
- return mlir::omp::ScheduleModifier::none;
- }
-
- return translateScheduleModifier(modType1.v);
- }
- return mlir::omp::ScheduleModifier::none;
-}
-
-static mlir::omp::ScheduleModifier
-getSimdModifier(const Fortran::parser::OmpScheduleClause &x) {
- const auto &modifier =
- std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
- // Either of the two possible modifiers in the input can be the SIMD modifier,
- // so look in either one, and return simd if we find one. Not found = return
- // "none".
- if (modifier) {
- const auto &modType1 =
- std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
- if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd)
- return mlir::omp::ScheduleModifier::simd;
-
- const auto &modType2 = std::get<
- std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
- modifier->t);
- if (modType2 && modType2->v.v ==
- Fortran::parser::OmpScheduleModifierType::ModType::Simd)
- return mlir::omp::ScheduleModifier::simd;
- }
- return mlir::omp::ScheduleModifier::none;
-}
-
-static void
-genAllocateClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpAllocateClause &ompAllocateClause,
- llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
- llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::Location currentLocation = converter.getCurrentLocation();
- Fortran::lower::StatementContext stmtCtx;
-
- mlir::Value allocatorOperand;
- const Fortran::parser::OmpObjectList &ompObjectList =
- std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
- const auto &allocateModifier = std::get<
- std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>(
- ompAllocateClause.t);
-
- // If the allocate modifier is present, check if we only use the allocator
- // submodifier. ALIGN in this context is unimplemented
- const bool onlyAllocator =
- allocateModifier &&
- std::holds_alternative<
- Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
- allocateModifier->u);
-
- if (allocateModifier && !onlyAllocator) {
- TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
- }
-
- // Check if allocate clause has allocator specified. If so, add it
- // to list of allocators, otherwise, add default allocator to
- // list of allocators.
- if (onlyAllocator) {
- const auto &allocatorValue = std::get<
- Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
- allocateModifier->u);
- allocatorOperand = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx));
- allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
- allocatorOperand);
- } else {
- allocatorOperand = firOpBuilder.createIntegerConstant(
- currentLocation, firOpBuilder.getI32Type(), 1);
- allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
- allocatorOperand);
- }
- genObjectList(ompObjectList, converter, allocateOperands);
-}
-
-static mlir::omp::ClauseProcBindKindAttr genProcBindKindAttr(
- fir::FirOpBuilder &firOpBuilder,
- const Fortran::parser::OmpClause::ProcBind *procBindClause) {
- mlir::omp::ClauseProcBindKind procBindKind;
- switch (procBindClause->v.v) {
- case Fortran::parser::OmpProcBindClause::Type::Master:
- procBindKind = mlir::omp::ClauseProcBindKind::Master;
- break;
- case Fortran::parser::OmpProcBindClause::Type::Close:
- procBindKind = mlir::omp::ClauseProcBindKind::Close;
- break;
- case Fortran::parser::OmpProcBindClause::Type::Spread:
- procBindKind = mlir::omp::ClauseProcBindKind::Spread;
- break;
- case Fortran::parser::OmpProcBindClause::Type::Primary:
- procBindKind = mlir::omp::ClauseProcBindKind::Primary;
- break;
- }
- return mlir::omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(),
- procBindKind);
-}
-
-static mlir::omp::ClauseTaskDependAttr
-genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
- const Fortran::parser::OmpClause::Depend *dependClause) {
- mlir::omp::ClauseTaskDepend pbKind;
- switch (
- std::get<Fortran::parser::OmpDependenceType>(
- std::get<Fortran::parser::OmpDependClause::InOut>(dependClause->v.u)
- .t)
- .v) {
- case Fortran::parser::OmpDependenceType::Type::In:
- pbKind = mlir::omp::ClauseTaskDepend::taskdependin;
- break;
- case Fortran::parser::OmpDependenceType::Type::Out:
- pbKind = mlir::omp::ClauseTaskDepend::taskdependout;
- break;
- case Fortran::parser::OmpDependenceType::Type::Inout:
- pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
- break;
- default:
- llvm_unreachable("unknown parser task dependence type");
- break;
- }
- return mlir::omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(),
- pbKind);
-}
-
-static mlir::Value getIfClauseOperand(
- Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpClause::If *ifClause,
- Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
- mlir::Location clauseLocation) {
- // Only consider the clause if it's intended for the given directive.
- auto &directive = std::get<
- std::optional<Fortran::parser::OmpIfClause::DirectiveNameModifier>>(
- ifClause->v.t);
- if (directive && directive.value() != directiveName)
- return nullptr;
-
- Fortran::lower::StatementContext stmtCtx;
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
- mlir::Value ifVal = fir::getBase(
- converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
- return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
- ifVal);
-}
-
-static void
-addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpObjectList &useDeviceClause,
- llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) {
- genObjectList(useDeviceClause, converter, operands);
- for (mlir::Value &operand : operands) {
- checkMapType(operand.getLoc(), operand.getType());
- useDeviceTypes.push_back(operand.getType());
- useDeviceLocs.push_back(operand.getLoc());
- }
- for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
- useDeviceSymbols.push_back(sym);
- }
-}
-
-//===----------------------------------------------------------------------===//
-// ClauseProcessor unique clauses
-//===----------------------------------------------------------------------===//
-
-bool ClauseProcessor::processCollapse(
- mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
- std::size_t &loopVarTypeSize) const {
- bool found = false;
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- // Collect the loops to collapse.
- Fortran::lower::pft::Evaluation *doConstructEval =
- &eval.getFirstNestedEvaluation();
- if (doConstructEval->getIf<Fortran::parser::DoConstruct>()
- ->IsDoConcurrent()) {
- TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
- }
-
- std::int64_t collapseValue = 1l;
- if (auto *collapseClause = findUniqueClause<ClauseTy::Collapse>()) {
- const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
- collapseValue = Fortran::evaluate::ToInt64(*expr).value();
- found = true;
- }
-
- loopVarTypeSize = 0;
- do {
- Fortran::lower::pft::Evaluation *doLoop =
- &doConstructEval->getFirstNestedEvaluation();
- auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
- assert(doStmt && "Expected do loop to be in the nested evaluation");
- const auto &loopControl =
- std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
- const Fortran::parser::LoopControl::Bounds *bounds =
- std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
- assert(bounds && "Expected bounds for worksharing do loop");
- Fortran::lower::StatementContext stmtCtx;
- lowerBound.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
- upperBound.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
- if (bounds->step) {
- step.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
- } else { // If `step` is not present, assume it as `1`.
- step.push_back(firOpBuilder.createIntegerConstant(
- currentLocation, firOpBuilder.getIntegerType(32), 1));
- }
- iv.push_back(bounds->name.thing.symbol);
- loopVarTypeSize = std::max(loopVarTypeSize,
- bounds->name.thing.symbol->GetUltimate().size());
- collapseValue--;
- doConstructEval =
- &*std::next(doConstructEval->getNestedEvaluations().begin());
- } while (collapseValue > 0);
-
- return found;
-}
-
-bool ClauseProcessor::processDefault() const {
- if (auto *defaultClause = findUniqueClause<ClauseTy::Default>()) {
- // Private, Firstprivate, Shared, None
- switch (defaultClause->v.v) {
- case Fortran::parser::OmpDefaultClause::Type::Shared:
- case Fortran::parser::OmpDefaultClause::Type::None:
- // Default clause with shared or none do not require any handling since
- // Shared is the default behavior in the IR and None is only required
- // for semantic checks.
- break;
- case Fortran::parser::OmpDefaultClause::Type::Private:
- // TODO Support default(private)
- break;
- case Fortran::parser::OmpDefaultClause::Type::Firstprivate:
- // TODO Support default(firstprivate)
- break;
- }
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
- const Fortran::parser::CharBlock *source = nullptr;
- if (auto *deviceClause = findUniqueClause<ClauseTy::Device>(&source)) {
- mlir::Location clauseLocation = converter.genLocation(*source);
- if (auto deviceModifier = std::get<
- std::optional<Fortran::parser::OmpDeviceClause::DeviceModifier>>(
- deviceClause->v.t)) {
- if (deviceModifier ==
- Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) {
- TODO(clauseLocation, "OMPD_target Device Modifier Ancestor");
- }
- }
- if (const auto *deviceExpr = Fortran::semantics::GetExpr(
- std::get<Fortran::parser::ScalarIntExpr>(deviceClause->v.t))) {
- result = fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
- }
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processDeviceType(
- mlir::omp::DeclareTargetDeviceType &result) const {
- if (auto *deviceTypeClause = findUniqueClause<ClauseTy::DeviceType>()) {
- // Case: declare target ... device_type(any | host | nohost)
- switch (deviceTypeClause->v.v) {
- case Fortran::parser::OmpDeviceTypeClause::Type::Nohost:
- result = mlir::omp::DeclareTargetDeviceType::nohost;
- break;
- case Fortran::parser::OmpDeviceTypeClause::Type::Host:
- result = mlir::omp::DeclareTargetDeviceType::host;
- break;
- case Fortran::parser::OmpDeviceTypeClause::Type::Any:
- result = mlir::omp::DeclareTargetDeviceType::any;
- break;
- }
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
- const Fortran::parser::CharBlock *source = nullptr;
- if (auto *finalClause = findUniqueClause<ClauseTy::Final>(&source)) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::Location clauseLocation = converter.genLocation(*source);
-
- mlir::Value finalVal = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
- result = firOpBuilder.createConvert(clauseLocation,
- firOpBuilder.getI1Type(), finalVal);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const {
- if (auto *hintClause = findUniqueClause<ClauseTy::Hint>()) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
- int64_t hintValue = *Fortran::evaluate::ToInt64(*expr);
- result = firOpBuilder.getI64IntegerAttr(hintValue);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const {
- return markClauseOccurrence<ClauseTy::Mergeable>(result);
-}
-
-bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
- return markClauseOccurrence<ClauseTy::Nowait>(result);
-}
-
-bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
- // TODO Get lower and upper bounds for num_teams when parser is updated to
- // accept both.
- if (auto *numTeamsClause = findUniqueClause<ClauseTy::NumTeams>()) {
- result = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(numTeamsClause->v), stmtCtx));
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processNumThreads(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
- if (auto *numThreadsClause = findUniqueClause<ClauseTy::NumThreads>()) {
- // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const {
- if (auto *orderedClause = findUniqueClause<ClauseTy::Ordered>()) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- int64_t orderedClauseValue = 0l;
- if (orderedClause->v.has_value()) {
- const auto *expr = Fortran::semantics::GetExpr(orderedClause->v);
- orderedClauseValue = *Fortran::evaluate::ToInt64(*expr);
- }
- result = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
- if (auto *priorityClause = findUniqueClause<ClauseTy::Priority>()) {
- result = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processProcBind(
- mlir::omp::ClauseProcBindKindAttr &result) const {
- if (auto *procBindClause = findUniqueClause<ClauseTy::ProcBind>()) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- result = genProcBindKindAttr(firOpBuilder, procBindClause);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const {
- if (auto *safelenClause = findUniqueClause<ClauseTy::Safelen>()) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- const auto *expr = Fortran::semantics::GetExpr(safelenClause->v);
- const std::optional<std::int64_t> safelenVal =
- Fortran::evaluate::ToInt64(*expr);
- result = firOpBuilder.getI64IntegerAttr(*safelenVal);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processSchedule(
- mlir::omp::ClauseScheduleKindAttr &valAttr,
- mlir::omp::ScheduleModifierAttr &modifierAttr,
- mlir::UnitAttr &simdModifierAttr) const {
- if (auto *scheduleClause = findUniqueClause<ClauseTy::Schedule>()) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::MLIRContext *context = firOpBuilder.getContext();
- const Fortran::parser::OmpScheduleClause &scheduleType = scheduleClause->v;
- const auto &scheduleClauseKind =
- std::get<Fortran::parser::OmpScheduleClause::ScheduleType>(
- scheduleType.t);
-
- mlir::omp::ClauseScheduleKind scheduleKind;
- switch (scheduleClauseKind) {
- case Fortran::parser::OmpScheduleClause::ScheduleType::Static:
- scheduleKind = mlir::omp::ClauseScheduleKind::Static;
- break;
- case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic:
- scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic;
- break;
- case Fortran::parser::OmpScheduleClause::ScheduleType::Guided:
- scheduleKind = mlir::omp::ClauseScheduleKind::Guided;
- break;
- case Fortran::parser::OmpScheduleClause::ScheduleType::Auto:
- scheduleKind = mlir::omp::ClauseScheduleKind::Auto;
- break;
- case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime:
- scheduleKind = mlir::omp::ClauseScheduleKind::Runtime;
- break;
- }
-
- mlir::omp::ScheduleModifier scheduleModifier =
- getScheduleModifier(scheduleClause->v);
-
- if (scheduleModifier != mlir::omp::ScheduleModifier::none)
- modifierAttr =
- mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
-
- if (getSimdModifier(scheduleClause->v) != mlir::omp::ScheduleModifier::none)
- simdModifierAttr = firOpBuilder.getUnitAttr();
-
- valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processScheduleChunk(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
- if (auto *scheduleClause = findUniqueClause<ClauseTy::Schedule>()) {
- if (const auto &chunkExpr =
- std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
- scheduleClause->v.t)) {
- if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) {
- result = fir::getBase(converter.genExprValue(*expr, stmtCtx));
- }
- }
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const {
- if (auto *simdlenClause = findUniqueClause<ClauseTy::Simdlen>()) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- const auto *expr = Fortran::semantics::GetExpr(simdlenClause->v);
- const std::optional<std::int64_t> simdlenVal =
- Fortran::evaluate::ToInt64(*expr);
- result = firOpBuilder.getI64IntegerAttr(*simdlenVal);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processThreadLimit(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
- if (auto *threadLmtClause = findUniqueClause<ClauseTy::ThreadLimit>()) {
- result = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx));
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
- return markClauseOccurrence<ClauseTy::Untied>(result);
-}
-
-//===----------------------------------------------------------------------===//
-// ClauseProcessor repeatable clauses
-//===----------------------------------------------------------------------===//
-
-bool ClauseProcessor::processAllocate(
- llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
- llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
- return findRepeatableClause<ClauseTy::Allocate>(
- [&](const ClauseTy::Allocate *allocateClause,
- const Fortran::parser::CharBlock &) {
- genAllocateClause(converter, allocateClause->v, allocatorOperands,
- allocateOperands);
- });
-}
-
-bool ClauseProcessor::processCopyin() const {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
- firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
- auto checkAndCopyHostAssociateVar =
- [&](Fortran::semantics::Symbol *sym,
- mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) {
- assert(sym->has<Fortran::semantics::HostAssocDetails>() &&
- "No host-association found");
- if (converter.isPresentShallowLookup(*sym))
- converter.copyHostAssociateVar(*sym, copyAssignIP);
- };
- bool hasCopyin = findRepeatableClause<ClauseTy::Copyin>(
- [&](const ClauseTy::Copyin *copyinClause,
- const Fortran::parser::CharBlock &) {
- const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v;
- for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
- if (const auto *commonDetails =
- sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
- for (const auto &mem : commonDetails->objects())
- checkAndCopyHostAssociateVar(&*mem, &insPt);
- break;
- }
- if (Fortran::semantics::IsAllocatableOrObjectPointer(
- &sym->GetUltimate()))
- TODO(converter.getCurrentLocation(),
- "pointer or allocatable variables in Copyin clause");
- assert(sym->has<Fortran::semantics::HostAssocDetails>() &&
- "No host-association found");
- checkAndCopyHostAssociateVar(sym);
- }
- });
-
- // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to
- // the execution of the associated structured block. Emit implicit barrier to
- // synchronize threads and avoid data races on propagation master's thread
- // values of threadprivate variables to local instances of that variables of
- // all other implicit threads.
- if (hasCopyin)
- firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
- firOpBuilder.restoreInsertionPoint(insPt);
- return hasCopyin;
-}
-
-bool ClauseProcessor::processDepend(
- llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
- llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- return findRepeatableClause<ClauseTy::Depend>(
- [&](const ClauseTy::Depend *dependClause,
- const Fortran::parser::CharBlock &) {
- const std::list<Fortran::parser::Designator> &depVal =
- std::get<std::list<Fortran::parser::Designator>>(
- std::get<Fortran::parser::OmpDependClause::InOut>(
- dependClause->v.u)
- .t);
- mlir::omp::ClauseTaskDependAttr dependTypeOperand =
- genDependKindAttr(firOpBuilder, dependClause);
- dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
- dependTypeOperand);
- for (const Fortran::parser::Designator &ompObject : depVal) {
- Fortran::semantics::Symbol *sym = nullptr;
- std::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::DataRef &designator) {
- if (const Fortran::parser::Name *name =
- std::get_if<Fortran::parser::Name>(&designator.u)) {
- sym = name->symbol;
- } else if (std::get_if<Fortran::common::Indirection<
- Fortran::parser::ArrayElement>>(
- &designator.u)) {
- TODO(converter.getCurrentLocation(),
- "array sections not supported for task depend");
- }
- },
- [&](const Fortran::parser::Substring &designator) {
- TODO(converter.getCurrentLocation(),
- "substring not supported for task depend");
- }},
- (ompObject).u);
- const mlir::Value variable = converter.getSymbolAddress(*sym);
- dependOperands.push_back(variable);
- }
- });
-}
-
-bool ClauseProcessor::processIf(
- Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
- mlir::Value &result) const {
- bool found = false;
- findRepeatableClause<ClauseTy::If>(
- [&](const ClauseTy::If *ifClause,
- const Fortran::parser::CharBlock &source) {
- mlir::Location clauseLocation = converter.genLocation(source);
- mlir::Value operand = getIfClauseOperand(converter, ifClause,
- directiveName, clauseLocation);
- // Assume that, at most, a single 'if' clause will be applicable to the
- // given directive.
- if (operand) {
- result = operand;
- found = true;
- }
- });
- return found;
-}
-
-bool ClauseProcessor::processLink(
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
- return findRepeatableClause<ClauseTy::Link>(
- [&](const ClauseTy::Link *linkClause,
- const Fortran::parser::CharBlock &) {
- // Case: declare target link(var1, var2)...
- gatherFuncAndVarSyms(
- linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result);
- });
-}
-
-static mlir::omp::MapInfoOp
-createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
- mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
- mlir::SmallVector<mlir::Value> bounds,
- mlir::SmallVector<mlir::Value> members, uint64_t mapType,
- mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
- bool isVal = false) {
- if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
- baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
- retTy = baseAddr.getType();
- }
-
- mlir::TypeAttr varType = mlir::TypeAttr::get(
- llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
-
- mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
- loc, retTy, baseAddr, varType, varPtrPtr, members, bounds,
- builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
- builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
- builder.getStringAttr(name));
-
- return op;
-}
-
-bool ClauseProcessor::processMap(
- mlir::Location currentLocation, const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
- llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
- const {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- return findRepeatableClause<ClauseTy::Map>(
- [&](const ClauseTy::Map *mapClause,
- const Fortran::parser::CharBlock &source) {
- mlir::Location clauseLocation = converter.genLocation(source);
- const auto &oMapType =
- std::get<std::optional<Fortran::parser::OmpMapType>>(
- mapClause->v.t);
- llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
- // If the map type is specified, then process it else Tofrom is the
- // default.
- if (oMapType) {
- const Fortran::parser::OmpMapType::Type &mapType =
- std::get<Fortran::parser::OmpMapType::Type>(oMapType->t);
- switch (mapType) {
- case Fortran::parser::OmpMapType::Type::To:
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
- break;
- case Fortran::parser::OmpMapType::Type::From:
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
- break;
- case Fortran::parser::OmpMapType::Type::Tofrom:
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
- break;
- case Fortran::parser::OmpMapType::Type::Alloc:
- case Fortran::parser::OmpMapType::Type::Release:
- // alloc and release is the default map_type for the Target Data
- // Ops, i.e. if no bits for map_type is supplied then alloc/release
- // is implicitly assumed based on the target directive. Default
- // value for Target Data and Enter Data is alloc and for Exit Data
- // it is release.
- break;
- case Fortran::parser::OmpMapType::Type::Delete:
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
- }
-
- if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
- oMapType->t))
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
- } else {
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
- }
-
- for (const Fortran::parser::OmpObject &ompObject :
- std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) {
- llvm::SmallVector<mlir::Value> bounds;
- std::stringstream asFortran;
-
- Fortran::lower::AddrAndBoundsInfo info =
- Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
- mlir::omp::DataBoundsType>(
- converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
- clauseLocation, asFortran, bounds, treatIndexAsSection);
-
- auto origSymbol =
- converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
- mlir::Value symAddr = info.addr;
- if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
- symAddr = origSymbol;
-
- // Explicit map captures are captured ByRef by default,
- // optimisation passes may alter this to ByCopy or other capture
- // types to optimise
- mlir::Value mapOp = createMapInfoOp(
- firOpBuilder, clauseLocation, symAddr, mlir::Value{},
- asFortran.str(), bounds, {},
- static_cast<
- std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
- mapTypeBits),
- mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
-
- mapOperands.push_back(mapOp);
- if (mapSymTypes)
- mapSymTypes->push_back(symAddr.getType());
- if (mapSymLocs)
- mapSymLocs->push_back(symAddr.getLoc());
-
- if (mapSymbols)
- mapSymbols->push_back(getOmpObjectSymbol(ompObject));
- }
- });
-}
-
-bool ClauseProcessor::processReduction(
- mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
- const {
- return findRepeatableClause<ClauseTy::Reduction>(
- [&](const ClauseTy::Reduction *reductionClause,
- const Fortran::parser::CharBlock &) {
- ReductionProcessor rp;
- rp.addReductionDecl(currentLocation, converter, reductionClause->v,
- reductionVars, reductionDeclSymbols,
- reductionSymbols);
- });
-}
-
-bool ClauseProcessor::processSectionsReduction(
- mlir::Location currentLocation) const {
- return findRepeatableClause<ClauseTy::Reduction>(
- [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) {
- TODO(currentLocation, "OMPC_Reduction");
- });
-}
-
-bool ClauseProcessor::processTo(
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
- return findRepeatableClause<ClauseTy::To>(
- [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) {
- // Case: declare target to(func, var1, var2)...
- gatherFuncAndVarSyms(toClause->v,
- mlir::omp::DeclareTargetCaptureClause::to, result);
- });
-}
-
-bool ClauseProcessor::processEnter(
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
- return findRepeatableClause<ClauseTy::Enter>(
- [&](const ClauseTy::Enter *enterClause,
- const Fortran::parser::CharBlock &) {
- // Case: declare target enter(func, var1, var2)...
- gatherFuncAndVarSyms(enterClause->v,
- mlir::omp::DeclareTargetCaptureClause::enter,
- result);
- });
-}
-
-bool ClauseProcessor::processUseDeviceAddr(
- llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
- const {
- return findRepeatableClause<ClauseTy::UseDeviceAddr>(
- [&](const ClauseTy::UseDeviceAddr *devAddrClause,
- const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, devAddrClause->v, operands,
- useDeviceTypes, useDeviceLocs, useDeviceSymbols);
- });
-}
-
-bool ClauseProcessor::processUseDevicePtr(
- llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
- const {
- return findRepeatableClause<ClauseTy::UseDevicePtr>(
- [&](const ClauseTy::UseDevicePtr *devPtrClause,
- const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes,
- useDeviceLocs, useDeviceSymbols);
- });
-}
-
-template <typename T>
-bool ClauseProcessor::processMotionClauses(
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
- return findRepeatableClause<T>(
- [&](const T *motionClause, const Fortran::parser::CharBlock &source) {
- mlir::Location clauseLocation = converter.genLocation(source);
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> ||
- std::is_same_v<T, ClauseProcessor::ClauseTy::From>);
-
- // TODO Support motion modifiers: present, mapper, iterator.
- constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
- std::is_same_v<T, ClauseProcessor::ClauseTy::To>
- ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
- : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
-
- for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
- llvm::SmallVector<mlir::Value> bounds;
- std::stringstream asFortran;
- Fortran::lower::AddrAndBoundsInfo info =
- Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
- mlir::omp::DataBoundsType>(
- converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
- clauseLocation, asFortran, bounds, treatIndexAsSection);
-
- auto origSymbol =
- converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
- mlir::Value symAddr = info.addr;
- if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
- symAddr = origSymbol;
-
- // Explicit map captures are captured ByRef by default,
- // optimisation passes may alter this to ByCopy or other capture
- // types to optimise
- mlir::Value mapOp = createMapInfoOp(
- firOpBuilder, clauseLocation, symAddr, mlir::Value{},
- asFortran.str(), bounds, {},
- static_cast<
- std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
- mapTypeBits),
- mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
-
- mapOperands.push_back(mapOp);
- }
- });
-}
-
-template <typename... Ts>
-void ClauseProcessor::processTODO(mlir::Location currentLocation,
- llvm::omp::Directive directive) const {
- auto checkUnhandledClause = [&](const auto *x) {
- if (!x)
- return;
- TODO(currentLocation,
- "Unhandled clause " +
- llvm::StringRef(Fortran::parser::ParseTreeDumper::GetNodeName(*x))
- .upper() +
- " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() +
- " construct");
- };
-
- for (ClauseIterator it = clauses.v.begin(); it != clauses.v.end(); ++it)
- (checkUnhandledClause(std::get_if<Ts>(&it->u)), ...);
-}
-
-//===----------------------------------------------------------------------===//
-// Code generation helper functions
-//===----------------------------------------------------------------------===//
-
static fir::GlobalOp globalInitialization(
Fortran::lower::AbstractConverter &converter,
fir::FirOpBuilder &firOpBuilder, const Fortran::semantics::Symbol &sym,
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
new file mode 100644
index 00000000000000..a8b98f3f567249
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -0,0 +1,431 @@
+//===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "ReductionProcessor.h"
+
+#include "flang/Lower/AbstractConverter.h"
+#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Parser/tools.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+
+namespace Fortran {
+namespace lower {
+namespace omp {
+
+// Maps the intrinsic procedure named by \p pd (resolved to its ultimate
+// symbol's name) to a reduction identifier. Only max, min, iand, ior and
+// ieor are recognized; any other name is invalid input (asserted).
+ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
+    const Fortran::parser::ProcedureDesignator &pd) {
+  const std::string procName = getRealName(pd).ToString();
+  std::optional<ReductionIdentifier> redType =
+      llvm::StringSwitch<std::optional<ReductionIdentifier>>(procName)
+          .Case("max", ReductionIdentifier::MAX)
+          .Case("min", ReductionIdentifier::MIN)
+          .Case("iand", ReductionIdentifier::IAND)
+          .Case("ior", ReductionIdentifier::IOR)
+          .Case("ieor", ReductionIdentifier::IEOR)
+          .Default(std::nullopt);
+  assert(redType && "Invalid Reduction");
+  return *redType;
+}
+
+// Maps an intrinsic operator from a REDUCTION clause to its reduction
+// identifier. Operators with no mapping below hit llvm_unreachable.
+ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
+    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+  using IntrinsicOp = Fortran::parser::DefinedOperator::IntrinsicOperator;
+  switch (intrinsicOp) {
+  case IntrinsicOp::Add:
+    return ReductionIdentifier::ADD;
+  case IntrinsicOp::Subtract:
+    return ReductionIdentifier::SUBTRACT;
+  case IntrinsicOp::Multiply:
+    return ReductionIdentifier::MULTIPLY;
+  case IntrinsicOp::AND:
+    return ReductionIdentifier::AND;
+  case IntrinsicOp::OR:
+    return ReductionIdentifier::OR;
+  case IntrinsicOp::EQV:
+    return ReductionIdentifier::EQV;
+  case IntrinsicOp::NEQV:
+    return ReductionIdentifier::NEQV;
+  default:
+    llvm_unreachable("unexpected intrinsic operator in reduction");
+  }
+}
+
+// Returns true when \p pd names an intrinsic procedure usable as a reduction
+// (max, min, iand, ior, ieor). The resolved (ultimate) symbol must carry the
+// INTRINSIC attribute; procedures without it are rejected regardless of name.
+bool ReductionProcessor::supportedIntrinsicProcReduction(
+    const Fortran::parser::ProcedureDesignator &pd) {
+  const Fortran::parser::Name *name =
+      Fortran::parser::Unwrap<Fortran::parser::Name>(pd);
+  assert(name && "Invalid Reduction Intrinsic.");
+  const Fortran::semantics::Symbol &ultimate = name->symbol->GetUltimate();
+  if (!ultimate.attrs().test(Fortran::semantics::Attr::INTRINSIC))
+    return false;
+  return llvm::StringSwitch<bool>(getRealName(name).ToString())
+      .Cases("max", "min", "iand", "ior", "ieor", true)
+      .Default(false);
+}
+
+// Appends a type suffix to \p name: "_i_<bits>" for integer/index types and
+// "_f_<bits>" otherwise, e.g. ("add_reduction", i32) -> "add_reduction_i_32".
+std::string ReductionProcessor::getReductionName(llvm::StringRef name,
+                                                 mlir::Type ty) {
+  const char *kindTag = ty.isIntOrIndex() ? "_i_" : "_f_";
+  unsigned bitWidth = ty.getIntOrFloatBitWidth();
+  return (llvm::Twine(name) + kindTag + llvm::Twine(bitWidth)).str();
+}
+
+// Names the reduction declaration for an intrinsic operator. The logical
+// operators (.and., .eqv., .or., .neqv.) get a fixed name with no type
+// suffix; + and * get "add_reduction"/"multiply_reduction" and any other
+// operator "other_reduction", each suffixed with the type via the
+// (StringRef, Type) overload.
+std::string ReductionProcessor::getReductionName(
+    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+    mlir::Type ty) {
+  using IntrinsicOp = Fortran::parser::DefinedOperator::IntrinsicOperator;
+  llvm::StringRef prefix;
+  switch (intrinsicOp) {
+  case IntrinsicOp::AND:
+    return "and_reduction";
+  case IntrinsicOp::EQV:
+    return "eqv_reduction";
+  case IntrinsicOp::OR:
+    return "or_reduction";
+  case IntrinsicOp::NEQV:
+    return "neqv_reduction";
+  case IntrinsicOp::Add:
+    prefix = "add_reduction";
+    break;
+  case IntrinsicOp::Multiply:
+    prefix = "multiply_reduction";
+    break;
+  default:
+    prefix = "other_reduction";
+    break;
+  }
+  return getReductionName(prefix, ty);
+}
+
+// Returns the value a reduction of kind \p redId on \p type starts from:
+//  - MAX / MIN: the most negative / most positive finite float, or the
+//    signed minimum / maximum integer of the type's bit width;
+//  - IOR / IEOR: zero; IAND: all bits set;
+//  - ADD, MULTIPLY and the logical operators: getOperationIdentity(),
+//    materialized as a float, as an i1 converted to the logical type, or as
+//    an integer.
+// ID, USER_DEF_OP and SUBTRACT are not supported yet (TODO).
+mlir::Value
+ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
+                                          ReductionIdentifier redId,
+                                          fir::FirOpBuilder &builder) {
+  assert((fir::isa_integer(type) || fir::isa_real(type) ||
+          type.isa<fir::LogicalType>()) &&
+         "only integer, logical and real types are currently supported");
+
+  // Builds an integer constant of `type` directly from a full-width APInt.
+  // Round-tripping through int64_t (APInt::getSExtValue) asserts for types
+  // wider than 64 bits such as INTEGER(16), and without assertions keeps only
+  // the low 64 bits (turning e.g. the i128 signed minimum into 0).
+  auto createIntConstant = [&](const llvm::APInt &value) -> mlir::Value {
+    return builder.create<mlir::arith::ConstantOp>(
+        loc, type, builder.getIntegerAttr(type, value));
+  };
+
+  switch (redId) {
+  case ReductionIdentifier::MAX: {
+    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+      const llvm::fltSemantics &sem = ty.getFloatSemantics();
+      return builder.createRealConstant(
+          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+    }
+    unsigned bits = type.getIntOrFloatBitWidth();
+    return createIntConstant(llvm::APInt::getSignedMinValue(bits));
+  }
+  case ReductionIdentifier::MIN: {
+    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+      const llvm::fltSemantics &sem = ty.getFloatSemantics();
+      return builder.createRealConstant(
+          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
+    }
+    unsigned bits = type.getIntOrFloatBitWidth();
+    return createIntConstant(llvm::APInt::getSignedMaxValue(bits));
+  }
+  case ReductionIdentifier::IOR:
+  case ReductionIdentifier::IEOR:
+    return createIntConstant(
+        llvm::APInt::getZero(type.getIntOrFloatBitWidth()));
+  case ReductionIdentifier::IAND:
+    return createIntConstant(
+        llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth()));
+  case ReductionIdentifier::ADD:
+  case ReductionIdentifier::MULTIPLY:
+  case ReductionIdentifier::AND:
+  case ReductionIdentifier::OR:
+  case ReductionIdentifier::EQV:
+  case ReductionIdentifier::NEQV:
+    if (type.isa<mlir::FloatType>())
+      return builder.create<mlir::arith::ConstantOp>(
+          loc, type,
+          builder.getFloatAttr(
+              type, static_cast<double>(getOperationIdentity(redId, loc))));
+
+    if (type.isa<fir::LogicalType>()) {
+      // Logical identities are built as i1 and converted to the logical type.
+      mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
+          loc, builder.getI1Type(),
+          builder.getIntegerAttr(builder.getI1Type(),
+                                 getOperationIdentity(redId, loc)));
+      return builder.createConvert(loc, type, intConst);
+    }
+
+    return builder.create<mlir::arith::ConstantOp>(
+        loc, type,
+        builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
+  case ReductionIdentifier::ID:
+  case ReductionIdentifier::USER_DEF_OP:
+  case ReductionIdentifier::SUBTRACT:
+    TODO(loc, "Reduction of some identifier types is not supported");
+  }
+  llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
+}
+
+// Emits the scalar operation combining \p op1 and \p op2 for \p redId:
+// MAX/MIN/ADD/MULTIPLY pass a float/integer arith op pair to
+// getReductionOperation, IOR/IEOR/IAND use the bitwise integer ops, and the
+// logical operators are evaluated on i1 and converted back to \p type.
+// Other identifiers are not supported yet (TODO).
+mlir::Value ReductionProcessor::createScalarCombiner(
+    fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
+    mlir::Type type, mlir::Value op1, mlir::Value op2) {
+  // Converts both operands to i1, combines them with `buildI1Op` and converts
+  // the result back to `type`.
+  auto combineAsI1 = [&](auto buildI1Op) -> mlir::Value {
+    mlir::Type i1Ty = builder.getI1Type();
+    mlir::Value lhs = builder.createConvert(loc, i1Ty, op1);
+    mlir::Value rhs = builder.createConvert(loc, i1Ty, op2);
+    return builder.createConvert(loc, type, buildI1Op(lhs, rhs));
+  };
+
+  mlir::Value combined;
+  switch (redId) {
+  case ReductionIdentifier::MAX:
+    combined =
+        getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
+            builder, type, loc, op1, op2);
+    break;
+  case ReductionIdentifier::MIN:
+    combined =
+        getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
+            builder, type, loc, op1, op2);
+    break;
+  case ReductionIdentifier::ADD:
+    combined = getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
+        builder, type, loc, op1, op2);
+    break;
+  case ReductionIdentifier::MULTIPLY:
+    combined = getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
+        builder, type, loc, op1, op2);
+    break;
+  case ReductionIdentifier::IOR:
+    assert((type.isIntOrIndex()) && "only integer is expected");
+    combined = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
+    break;
+  case ReductionIdentifier::IEOR:
+    assert((type.isIntOrIndex()) && "only integer is expected");
+    combined = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
+    break;
+  case ReductionIdentifier::IAND:
+    assert((type.isIntOrIndex()) && "only integer is expected");
+    combined = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
+    break;
+  case ReductionIdentifier::AND:
+    combined = combineAsI1([&](mlir::Value a, mlir::Value b) -> mlir::Value {
+      return builder.create<mlir::arith::AndIOp>(loc, a, b);
+    });
+    break;
+  case ReductionIdentifier::OR:
+    combined = combineAsI1([&](mlir::Value a, mlir::Value b) -> mlir::Value {
+      return builder.create<mlir::arith::OrIOp>(loc, a, b);
+    });
+    break;
+  case ReductionIdentifier::EQV:
+    combined = combineAsI1([&](mlir::Value a, mlir::Value b) -> mlir::Value {
+      return builder.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::eq, a, b);
+    });
+    break;
+  case ReductionIdentifier::NEQV:
+    combined = combineAsI1([&](mlir::Value a, mlir::Value b) -> mlir::Value {
+      return builder.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::ne, a, b);
+    });
+    break;
+  default:
+    TODO(loc, "Reduction of some intrinsic operators is not supported");
+  }
+
+  return combined;
+}
+
+// Returns the module-level ReductionDeclareOp named \p reductionOpName,
+// creating it if the module does not already contain one. A new declaration
+// gets an initializer region (one argument of \p type, yielding the identity
+// from getReductionInitValue) and a reduction region (two arguments of
+// \p type, yielding the createScalarCombiner result). The caller's insertion
+// point is restored on return.
+mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
+    fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  mlir::ModuleOp module = builder.getModule();
+
+  // Declarations are looked up by name, so an existing one is reused as is.
+  if (auto existing =
+          module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName))
+    return existing;
+
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+  auto decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+      loc, reductionOpName, type);
+
+  // Initializer region: yields the identity value of the reduction.
+  mlir::Region &initRegion = decl.getInitializerRegion();
+  builder.createBlock(&initRegion, initRegion.end(), {type}, {loc});
+  builder.setInsertionPointToEnd(&initRegion.back());
+  builder.create<mlir::omp::YieldOp>(
+      loc, getReductionInitValue(loc, type, redId, builder));
+
+  // Reduction region: yields the combination of its two block arguments.
+  mlir::Region &combineRegion = decl.getReductionRegion();
+  builder.createBlock(&combineRegion, combineRegion.end(), {type, type},
+                      {loc, loc});
+  builder.setInsertionPointToEnd(&combineRegion.back());
+  mlir::Block &combineBlock = combineRegion.front();
+  builder.create<mlir::omp::YieldOp>(
+      loc, createScalarCombiner(builder, loc, redId, type,
+                                combineBlock.getArgument(0),
+                                combineBlock.getArgument(1)));
+
+  return decl;
+}
+
+// Lowers one REDUCTION clause. For every object of the clause that is a plain
+// name with a symbol, this appends:
+//  - the object's symbol to \p reductionSymbols (when non-null),
+//  - its address (looking through hlfir.declare) to \p reductionVars,
+//  - the symbol of the matching ReductionDeclareOp (obtained through
+//    createReductionDecl) to \p reductionDeclSymbols.
+// Intrinsic operators are supported for +, *, .and., .eqv., .or. and .neqv.
+// (others hit a TODO). Intrinsic procedures are handled only when
+// supportedIntrinsicProcReduction() accepts them; otherwise nothing is added.
+void ReductionProcessor::addReductionDecl(
+    mlir::Location currentLocation,
+    Fortran::lower::AbstractConverter &converter,
+    const Fortran::parser::OmpReductionClause &reduction,
+    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+        *reductionSymbols) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  const auto &redOperator{
+      std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
+  const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+
+  // Visits each clause object that is a plain name with a symbol, records its
+  // symbol and address, and asks `getDecl` (given the element type of the
+  // object's reference type) for the declaration to reduce it with.
+  auto forEachReductionObject = [&](auto getDecl) {
+    for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+      const auto *name{
+          Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)};
+      if (!name || !name->symbol)
+        continue;
+      const Fortran::semantics::Symbol *symbol = name->symbol;
+      if (reductionSymbols)
+        reductionSymbols->push_back(symbol);
+      mlir::Value symVal = converter.getSymbolAddress(*symbol);
+      // Look through an hlfir.declare to its base address.
+      if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+        symVal = declOp.getBase();
+      mlir::Type redType =
+          symVal.getType().cast<fir::ReferenceType>().getEleTy();
+      reductionVars.push_back(symVal);
+      mlir::omp::ReductionDeclareOp decl = getDecl(redType);
+      reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+          firOpBuilder.getContext(), decl.getSymName()));
+    }
+  };
+
+  if (const auto *redDefinedOp =
+          std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+    const auto &intrinsicOp{
+        std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+            redDefinedOp->u)};
+    ReductionIdentifier redId = getReductionType(intrinsicOp);
+    switch (redId) {
+    case ReductionIdentifier::ADD:
+    case ReductionIdentifier::MULTIPLY:
+    case ReductionIdentifier::AND:
+    case ReductionIdentifier::EQV:
+    case ReductionIdentifier::OR:
+    case ReductionIdentifier::NEQV:
+      break;
+    default:
+      TODO(currentLocation,
+           "Reduction of some intrinsic operators is not supported");
+      break;
+    }
+    forEachReductionObject(
+        [&](mlir::Type redType) -> mlir::omp::ReductionDeclareOp {
+          // Logical reductions are named as if on i1 but declared on the
+          // logical type itself.
+          if (redType.isa<fir::LogicalType>())
+            return createReductionDecl(
+                firOpBuilder,
+                getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
+                redId, redType, currentLocation);
+          if (!redType.isIntOrIndexOrFloat())
+            TODO(currentLocation, "Reduction of some types is not supported");
+          return createReductionDecl(firOpBuilder,
+                                     getReductionName(intrinsicOp, redType),
+                                     redId, redType, currentLocation);
+        });
+    return;
+  }
+
+  if (const auto *reductionIntrinsic =
+          std::get_if<Fortran::parser::ProcedureDesignator>(&redOperator.u)) {
+    // Unsupported intrinsic procedures add nothing here.
+    if (!supportedIntrinsicProcReduction(*reductionIntrinsic))
+      return;
+    ReductionIdentifier redId = getReductionType(*reductionIntrinsic);
+    forEachReductionObject([&](mlir::Type redType) {
+      assert(redType.isIntOrIndexOrFloat() && "Unsupported reduction type");
+      return createReductionDecl(
+          firOpBuilder,
+          getReductionName(getRealName(*reductionIntrinsic).ToString(),
+                           redType),
+          redId, redType, currentLocation);
+    });
+  }
+}
+
+// Returns the source name of the ultimate symbol that \p name resolves to.
+const Fortran::semantics::SourceName
+ReductionProcessor::getRealName(const Fortran::parser::Name *name) {
+  const Fortran::semantics::Symbol &ultimate = name->symbol->GetUltimate();
+  return ultimate.name();
+}
+
+// Returns the ultimate symbol name for the name wrapped by \p pd; the
+// designator is expected to be a plain name (asserted).
+const Fortran::semantics::SourceName ReductionProcessor::getRealName(
+    const Fortran::parser::ProcedureDesignator &pd) {
+  const Fortran::parser::Name *name =
+      Fortran::parser::Unwrap<Fortran::parser::Name>(pd);
+  assert(name && "Invalid Reduction Intrinsic.");
+  return getRealName(name);
+}
+
+// Returns the neutral element of an intrinsic-operator reduction: 1 for
+// MULTIPLY, AND and EQV; 0 for ADD, OR and NEQV. Other identifiers are not
+// supported yet (TODO).
+int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
+                                             mlir::Location loc) {
+  switch (redId) {
+  case ReductionIdentifier::MULTIPLY:
+  case ReductionIdentifier::AND:
+  case ReductionIdentifier::EQV:
+    return 1;
+  case ReductionIdentifier::ADD:
+  case ReductionIdentifier::OR:
+  case ReductionIdentifier::NEQV:
+    return 0;
+  default:
+    TODO(loc, "Reduction of some intrinsic operators is not supported");
+  }
+}
+
+} // namespace omp
+} // namespace lower
+} // namespace Fortran
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
new file mode 100644
index 00000000000000..00770fe81d1ef6
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -0,0 +1,138 @@
+//===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
+#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/symbol.h"
+#include "flang/Semantics/type.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Types.h"
+
+// Forward declarations for types this header only names in declarations.
+namespace Fortran::lower {
+class AbstractConverter;
+} // namespace Fortran::lower
+
+namespace mlir::omp {
+class ReductionDeclareOp;
+} // namespace mlir::omp
+
+namespace Fortran {
+namespace lower {
+namespace omp {
+
+/// Static helpers used to lower OpenMP `reduction` clauses. The class holds no
+/// state; it only groups these functions.
+class ReductionProcessor {
+public:
+  /// Identifies the operator or intrinsic procedure named by a reduction
+  /// clause.
+  // TODO: Move this enumeration to the OpenMP dialect
+  enum ReductionIdentifier {
+    ID,
+    USER_DEF_OP,
+    ADD,
+    SUBTRACT,
+    MULTIPLY,
+    AND,
+    OR,
+    EQV,
+    NEQV,
+    MAX,
+    MIN,
+    IAND,
+    IOR,
+    IEOR
+  };
+
+  /// Returns the reduction identifier for the intrinsic procedure designated
+  /// by \p pd.
+  static ReductionIdentifier
+  getReductionType(const Fortran::parser::ProcedureDesignator &pd);
+
+  /// Returns the reduction identifier for the intrinsic operator
+  /// \p intrinsicOp.
+  static ReductionIdentifier getReductionType(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp);
+
+  /// Returns whether the intrinsic procedure designated by \p pd is supported
+  /// as a reduction.
+  static bool supportedIntrinsicProcReduction(
+      const Fortran::parser::ProcedureDesignator &pd);
+
+  /// Returns the name of the ultimate symbol (`GetUltimate()`) that \p name
+  /// resolves to. \p name and its symbol must be non-null.
+  static const Fortran::semantics::SourceName
+  getRealName(const Fortran::parser::Name *name);
+
+  /// Same as the overload above, for a procedure designator that must unwrap
+  /// to a `parser::Name` (asserted).
+  static const Fortran::semantics::SourceName
+  getRealName(const Fortran::parser::ProcedureDesignator &pd);
+
+  /// Returns the name used for the reduction declaration of the reduction
+  /// \p name over values of type \p ty.
+  static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
+
+  /// Returns the name used for the reduction declaration of the intrinsic
+  /// operator \p intrinsicOp over values of type \p ty.
+  static std::string getReductionName(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      mlir::Type ty);
+
+  /// Returns the identity value of the reduction operator \p redId, i.e. the
+  /// value `e` such that `e op x == x`:
+  ///   0 for ADD, OR and NEQV (e.g. 0 + x = x),
+  ///   1 for MULTIPLY, AND and EQV (e.g. 1 * x = x).
+  /// Any other \p redId is reported as not yet supported (TODO) at \p loc.
+  static int getOperationIdentity(ReductionIdentifier redId,
+                                  mlir::Location loc);
+
+  /// Creates, at \p loc, the initial value of type \p type for a \p redId
+  /// reduction.
+  /// NOTE(review): presumably the operator's identity value — confirm against
+  /// the definition.
+  static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
+                                           ReductionIdentifier redId,
+                                           fir::FirOpBuilder &builder);
+
+  /// Creates `IntegerOp(op1, op2)` when \p type is an integer or index type
+  /// and `FloatOp(op1, op2)` otherwise. \p type must be an integer, index or
+  /// floating-point type (asserted).
+  template <typename FloatOp, typename IntegerOp>
+  static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+                                           mlir::Type type, mlir::Location loc,
+                                           mlir::Value op1, mlir::Value op2);
+
+  /// Creates the operation combining the scalar values \p op1 and \p op2 of
+  /// type \p type for a \p redId reduction.
+  static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
+                                          mlir::Location loc,
+                                          ReductionIdentifier redId,
+                                          mlir::Type type, mlir::Value op1,
+                                          mlir::Value op2);
+
+  /// Creates an OpenMP reduction declaration (`mlir::omp::ReductionDeclareOp`)
+  /// for a \p redId reduction over values of type \p type, named after
+  /// \p reductionOpName.
+  /// TODO: Generalize this for non-integer types, add atomic region.
+  static mlir::omp::ReductionDeclareOp createReductionDecl(
+      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+      const ReductionIdentifier redId, mlir::Type type, mlir::Location loc);
+
+  /// Creates the reduction declarations required by the clause \p reduction
+  /// and records the clause's operands: each reduction variable's address is
+  /// appended to \p reductionVars, and a symbol reference to its declaration
+  /// is appended to \p reductionDeclSymbols. When \p reductionSymbols is
+  /// non-null, each variable's symbol is appended to it as well.
+  static void
+  addReductionDecl(mlir::Location currentLocation,
+                   Fortran::lower::AbstractConverter &converter,
+                   const Fortran::parser::OmpReductionClause &reduction,
+                   llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+                   llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+                   llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+                       *reductionSymbols = nullptr);
+};
+
+/// Emits the binary combiner for \p op1 and \p op2: `FloatOp` for
+/// floating-point \p type, `IntegerOp` for integer or index \p type.
+template <typename FloatOp, typename IntegerOp>
+mlir::Value
+ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
+                                          mlir::Type type, mlir::Location loc,
+                                          mlir::Value op1, mlir::Value op2) {
+  assert(type.isIntOrIndexOrFloat() &&
+         "only integer and float types are currently supported");
+  // After the assertion, anything that is not an integer or index is a float.
+  if (!type.isIntOrIndex())
+    return builder.create<FloatOp>(loc, op1, op2);
+  return builder.create<IntegerOp>(loc, op1, op2);
+}
+
+} // namespace omp
+} // namespace lower
+} // namespace Fortran
+
+#endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
new file mode 100644
index 00000000000000..31b15257d18687
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -0,0 +1,99 @@
+//===-- Lower/OpenMP/Utils.cpp ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils.h"
+
+#include <flang/Lower/AbstractConverter.h>
+#include <flang/Lower/ConvertType.h>
+#include <flang/Parser/parse-tree.h>
+#include <flang/Parser/tools.h>
+#include <flang/Semantics/tools.h>
+#include <llvm/Support/CommandLine.h>
+
+// Enabled by default. When set, an array element such as `a(N)` in an OpenMP
+// data clause is treated as the section `a(N:N)`. Declared `extern` in Utils.h.
+llvm::cl::opt<bool> treatIndexAsSection(
+    "openmp-treat-index-as-section", llvm::cl::init(true),
+    llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."));
+
+namespace Fortran {
+namespace lower {
+namespace omp {
+
+/// Appends to \p operands the address of the symbol of each object in
+/// \p objectList.
+///
+/// If \p converter has no address for a symbol but the symbol carries
+/// `HostAssocDetails`, the host symbol's address is used instead and
+/// `copySymbolBinding` is called from the host symbol to the symbol. A symbol
+/// with neither an address nor host-association details contributes nothing.
+void genObjectList(const Fortran::parser::OmpObjectList &objectList,
+                   Fortran::lower::AbstractConverter &converter,
+                   llvm::SmallVectorImpl<mlir::Value> &operands) {
+  auto addOperands = [&](Fortran::lower::SymbolRef sym) {
+    const mlir::Value variable = converter.getSymbolAddress(sym);
+    if (variable) {
+      operands.push_back(variable);
+    } else if (const auto *details =
+                   sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
+      operands.push_back(converter.getSymbolAddress(details->symbol()));
+      converter.copySymbolBinding(details->symbol(), sym);
+    }
+  };
+  for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+    Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+    // getOmpObjectSymbol returns nullptr for designators it does not
+    // recognize; fail loudly rather than dereference a null pointer.
+    assert(sym && "Expected a symbol for the OpenMP object");
+    addOperands(*sym);
+  }
+}
+
+/// Appends a (\p clause, symbol) pair to \p symbolAndClause for every object
+/// in \p objList that is a bare name, or a designator that
+/// `getDesignatorNameIfDataRef` reduces to a name. Other designators are
+/// skipped.
+void gatherFuncAndVarSyms(
+    const Fortran::parser::OmpObjectList &objList,
+    mlir::omp::DeclareTargetCaptureClause clause,
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
+  auto recordName = [&](const Fortran::parser::Name &name) {
+    symbolAndClause.emplace_back(clause, *name.symbol);
+  };
+  auto onDesignator = [&](const Fortran::parser::Designator &designator) {
+    const Fortran::parser::Name *name =
+        Fortran::semantics::getDesignatorNameIfDataRef(designator);
+    if (name)
+      recordName(*name);
+  };
+  auto onName = [&](const Fortran::parser::Name &name) { recordName(name); };
+
+  for (const Fortran::parser::OmpObject &object : objList.v)
+    Fortran::common::visit(Fortran::common::visitors{onDesignator, onName},
+                           object.u);
+}
+
+/// Returns the symbol referenced by an OpenMP clause object.
+///
+/// For a designator, the following are tried in order:
+/// - the symbol of the first name of an array element's base;
+/// - the component symbol of a structure component;
+/// - the symbol of a plain data-ref name.
+/// For a bare name, the result is the name's symbol. Returns nullptr when a
+/// designator matches none of these forms.
+Fortran::semantics::Symbol *
+getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
+  Fortran::semantics::Symbol *sym = nullptr;
+  Fortran::common::visit(
+      Fortran::common::visitors{
+          [&](const Fortran::parser::Designator &designator) {
+            if (auto *arrayEle =
+                    Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
+                        designator)) {
+              sym = GetFirstName(arrayEle->base).symbol;
+            } else if (auto *structComp = Fortran::parser::Unwrap<
+                           Fortran::parser::StructureComponent>(designator)) {
+              sym = structComp->component.symbol;
+            } else if (const Fortran::parser::Name *name =
+                           Fortran::semantics::getDesignatorNameIfDataRef(
+                               designator)) {
+              sym = name->symbol;
+            }
+          },
+          [&](const Fortran::parser::Name &name) { sym = name.symbol; }},
+      ompObject.u);
+  return sym;
+}
+
+} // namespace omp
+} // namespace lower
+} // namespace Fortran
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
new file mode 100644
index 00000000000000..c346f891f0797e
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -0,0 +1,68 @@
+//===-- Lower/OpenMP/Utils.h ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_LOWER_OPENMPUTILS_H
+#define FORTRAN_LOWER_OPENMPUTILS_H
+
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Value.h"
+#include "llvm/Support/CommandLine.h"
+
+/// When set, OpenMP data clauses treat an array element `a(N)` as the section
+/// `a(N:N)`. Controlled by `-openmp-treat-index-as-section`.
+extern llvm::cl::opt<bool> treatIndexAsSection;
+
+namespace fir {
+class FirOpBuilder;
+} // namespace fir
+
+namespace Fortran::semantics {
+class Symbol;
+} // namespace Fortran::semantics
+
+namespace Fortran::parser {
+struct OmpObject;
+struct OmpObjectList;
+} // namespace Fortran::parser
+
+namespace Fortran::lower {
+class AbstractConverter;
+} // namespace Fortran::lower
+
+namespace Fortran::lower::omp {
+
+/// A declare-target clause kind paired with a symbol listed under it.
+/// NOTE(review): the symbol is stored by value, so building these pairs copies
+/// `Symbol` objects; confirm a copy (rather than a reference) is intended.
+using DeclareTargetCapturePair =
+    std::pair<mlir::omp::DeclareTargetCaptureClause,
+              Fortran::semantics::Symbol>;
+
+/// Creates an `mlir::omp::MapInfoOp` at \p loc using \p builder.
+mlir::omp::MapInfoOp
+createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
+                mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
+                mlir::SmallVector<mlir::Value> bounds,
+                mlir::SmallVector<mlir::Value> members, uint64_t mapType,
+                mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
+                bool isVal = false);
+
+/// Appends a (\p clause, symbol) pair to \p symbolAndClause for each object in
+/// \p objList that is a name, or a designator reducible to a data-ref name;
+/// other designators are skipped.
+void gatherFuncAndVarSyms(
+    const Fortran::parser::OmpObjectList &objList,
+    mlir::omp::DeclareTargetCaptureClause clause,
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);
+
+/// Returns the symbol an OpenMP clause object refers to, or nullptr if it
+/// cannot be determined from the object's form.
+Fortran::semantics::Symbol *
+getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject);
+
+/// Appends the address of each object's symbol in \p objectList to
+/// \p operands, falling back to the host-associated symbol when the symbol
+/// itself has no address in \p converter.
+void genObjectList(const Fortran::parser::OmpObjectList &objectList,
+                   Fortran::lower::AbstractConverter &converter,
+                   llvm::SmallVectorImpl<mlir::Value> &operands);
+
+} // namespace Fortran::lower::omp
+
+#endif // FORTRAN_LOWER_OPENMPUTILS_H
More information about the flang-commits
mailing list