[llvm-branch-commits] [flang] [flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProcessor (PR #81623)
Krzysztof Parzyszek via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Feb 23 06:01:20 PST 2024
https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/81623
>From 655dce519efb87f8d3babf3b7a5d6132bb82e2a6 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 21 Feb 2024 15:51:38 -0600
Subject: [PATCH] [flang][OpenMP] Convert repeatable clauses (except Map) in
ClauseProcessor
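Rename `findRepeatableClause` to `findRepeatableClause2`, and make the
new `findRepeatableClause` operate on the new `omp::Clause` objects.
Leave `Map` (and the motion clauses handled by `processMotionClauses`)
on `findRepeatableClause2`, because converting them requires further
changes. Also add `Fortran::evaluate::ExtractSubstring`, which is used
to detect substring designators in DEPEND clauses.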
---
flang/include/flang/Evaluate/tools.h | 23 ++
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 218 ++++++++----------
flang/lib/Lower/OpenMP/ClauseProcessor.h | 29 ++-
flang/lib/Lower/OpenMP/Clauses.cpp | 6 -
flang/lib/Lower/OpenMP/Clauses.h | 6 +
flang/lib/Lower/OpenMP/OpenMP.cpp | 182 +++++++--------
flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 155 ++++++-------
flang/lib/Lower/OpenMP/ReductionProcessor.h | 23 +-
flang/lib/Lower/OpenMP/Utils.cpp | 41 ++--
flang/lib/Lower/OpenMP/Utils.h | 10 +-
10 files changed, 348 insertions(+), 345 deletions(-)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index d257da1a709642..e9999974944e88 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -430,6 +430,36 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
}
}
+// Helper for ExtractSubstring: recursively unwraps the variants held by
+// Expr<T> and Designator<T> looking for a Substring alternative.
+struct ExtractSubstringHelper {
+ // Catch-all: anything other than an Expr, a Designator, or a Substring
+ // does not denote a substring.
+ template <typename T> static std::optional<Substring> visit(T &&) {
+ return std::nullopt;
+ }
+
+ static std::optional<Substring> visit(const Substring &e) { return e; }
+
+ template <typename T>
+ static std::optional<Substring> visit(const Designator<T> &e) {
+ return std::visit([](auto &&s) { return visit(s); }, e.u);
+ }
+
+ template <typename T>
+ static std::optional<Substring> visit(const Expr<T> &e) {
+ return std::visit([](auto &&s) { return visit(s); }, e.u);
+ }
+};
+
+// If an expression or designator is (possibly nested inside Expr and
+// Designator wrappers) a substring reference, extract and return that
+// Substring by value, else std::nullopt.
+template <typename A>
+std::optional<Substring> ExtractSubstring(const A &x) {
+ return ExtractSubstringHelper::visit(x);
+}
+
// If an expression is simply a whole symbol data designator,
// extract and return that symbol, else null.
template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 9987cd73fc7670..6e45a939333d62 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -87,7 +87,7 @@ getSimdModifier(const omp::clause::Schedule &clause) {
static void
genAllocateClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpAllocateClause &ompAllocateClause,
+ const omp::clause::Allocate &clause,
llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -95,21 +95,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value allocatorOperand;
- const Fortran::parser::OmpObjectList &ompObjectList =
- std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
- const auto &allocateModifier = std::get<
- std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>(
- ompAllocateClause.t);
+ const omp::ObjectList &objectList = std::get<omp::ObjectList>(clause.t);
+ const auto &modifier =
+ std::get<std::optional<omp::clause::Allocate::Modifier>>(clause.t);
// If the allocate modifier is present, check if we only use the allocator
// submodifier. ALIGN in this context is unimplemented
const bool onlyAllocator =
- allocateModifier &&
- std::holds_alternative<
- Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
- allocateModifier->u);
+ modifier &&
+ std::holds_alternative<omp::clause::Allocate::Modifier::Allocator>(
+ modifier->u);
- if (allocateModifier && !onlyAllocator) {
+ if (modifier && !onlyAllocator) {
TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
}
@@ -117,20 +114,17 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
// to list of allocators, otherwise, add default allocator to
// list of allocators.
if (onlyAllocator) {
- const auto &allocatorValue = std::get<
- Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
- allocateModifier->u);
- allocatorOperand = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx));
- allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
- allocatorOperand);
+ const auto &value =
+ std::get<omp::clause::Allocate::Modifier::Allocator>(modifier->u);
+ mlir::Value operand =
+ fir::getBase(converter.genExprValue(value.v, stmtCtx));
+ allocatorOperands.append(objectList.size(), operand);
} else {
- allocatorOperand = firOpBuilder.createIntegerConstant(
+ mlir::Value operand = firOpBuilder.createIntegerConstant(
currentLocation, firOpBuilder.getI32Type(), 1);
- allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
- allocatorOperand);
+ allocatorOperands.append(objectList.size(), operand);
}
- genObjectList(ompObjectList, converter, allocateOperands);
+ genObjectList(objectList, converter, allocateOperands);
}
static mlir::omp::ClauseProcBindKindAttr
@@ -157,20 +151,17 @@ genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder,
static mlir::omp::ClauseTaskDependAttr
genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
- const Fortran::parser::OmpClause::Depend *dependClause) {
+ const omp::clause::Depend &clause) {
mlir::omp::ClauseTaskDepend pbKind;
- switch (
- std::get<Fortran::parser::OmpDependenceType>(
- std::get<Fortran::parser::OmpDependClause::InOut>(dependClause->v.u)
- .t)
- .v) {
- case Fortran::parser::OmpDependenceType::Type::In:
+ const auto &inOut = std::get<omp::clause::Depend::InOut>(clause.u);
+ switch (std::get<omp::clause::Depend::Type>(inOut.t)) {
+ case omp::clause::Depend::Type::In:
pbKind = mlir::omp::ClauseTaskDepend::taskdependin;
break;
- case Fortran::parser::OmpDependenceType::Type::Out:
+ case omp::clause::Depend::Type::Out:
pbKind = mlir::omp::ClauseTaskDepend::taskdependout;
break;
- case Fortran::parser::OmpDependenceType::Type::Inout:
+ case omp::clause::Depend::Type::Inout:
pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
break;
default:
@@ -181,45 +172,41 @@ genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
pbKind);
}
-static mlir::Value getIfClauseOperand(
- Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpClause::If *ifClause,
- Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
- mlir::Location clauseLocation) {
+static mlir::Value
+getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
+ const omp::clause::If &clause,
+ omp::clause::If::DirectiveNameModifier directiveName,
+ mlir::Location clauseLocation) {
// Only consider the clause if it's intended for the given directive.
- auto &directive = std::get<
- std::optional<Fortran::parser::OmpIfClause::DirectiveNameModifier>>(
- ifClause->v.t);
+ auto &directive =
+ std::get<std::optional<omp::clause::If::DirectiveNameModifier>>(clause.t);
if (directive && directive.value() != directiveName)
return nullptr;
Fortran::lower::StatementContext stmtCtx;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
mlir::Value ifVal = fir::getBase(
- converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
+ converter.genExprValue(std::get<omp::SomeExpr>(clause.t), stmtCtx));
return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
ifVal);
}
static void
addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpObjectList &useDeviceClause,
+ const omp::ObjectList &objects,
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&useDeviceSymbols) {
- genObjectList(useDeviceClause, converter, operands);
+ genObjectList(objects, converter, operands);
for (mlir::Value &operand : operands) {
checkMapType(operand.getLoc(), operand.getType());
useDeviceTypes.push_back(operand.getType());
useDeviceLocs.push_back(operand.getLoc());
}
- for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
- useDeviceSymbols.push_back(sym);
- }
+ for (const omp::Object &object : objects)
+ useDeviceSymbols.push_back(object.id());
}
//===----------------------------------------------------------------------===//
@@ -527,10 +514,10 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
bool ClauseProcessor::processAllocate(
llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
- return findRepeatableClause<ClauseTy::Allocate>(
- [&](const ClauseTy::Allocate *allocateClause,
+ return findRepeatableClause<omp::clause::Allocate>(
+ [&](const omp::clause::Allocate &clause,
const Fortran::parser::CharBlock &) {
- genAllocateClause(converter, allocateClause->v, allocatorOperands,
+ genAllocateClause(converter, clause, allocatorOperands,
allocateOperands);
});
}
@@ -547,12 +534,12 @@ bool ClauseProcessor::processCopyin() const {
if (converter.isPresentShallowLookup(*sym))
converter.copyHostAssociateVar(*sym, copyAssignIP);
};
- bool hasCopyin = findRepeatableClause<ClauseTy::Copyin>(
- [&](const ClauseTy::Copyin *copyinClause,
+ bool hasCopyin = findRepeatableClause<omp::clause::Copyin>(
+ [&](const omp::clause::Copyin &clause,
const Fortran::parser::CharBlock &) {
- const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v;
- for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+ for (const omp::Object &object : clause.v) {
+ Fortran::semantics::Symbol *sym = object.id();
+ assert(sym && "Expecting symbol");
if (const auto *commonDetails =
sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
for (const auto &mem : commonDetails->objects())
@@ -716,13 +703,12 @@ bool ClauseProcessor::processCopyPrivate(
copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
};
- bool hasCopyPrivate = findRepeatableClause<ClauseTy::Copyprivate>(
- [&](const ClauseTy::Copyprivate *copyPrivateClause,
+ bool hasCopyPrivate = findRepeatableClause<omp::clause::Copyprivate>(
+ [&](const omp::clause::Copyprivate &clause,
const Fortran::parser::CharBlock &) {
- const Fortran::parser::OmpObjectList &ompObjectList =
- copyPrivateClause->v;
- for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+ for (const omp::Object &object : clause.v) {
+ Fortran::semantics::Symbol *sym = object.id();
+ assert(sym && "Expecting symbol");
if (const auto *commonDetails =
sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
for (const auto &mem : commonDetails->objects())
@@ -741,38 +726,30 @@ bool ClauseProcessor::processDepend(
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- return findRepeatableClause<ClauseTy::Depend>(
- [&](const ClauseTy::Depend *dependClause,
+ return findRepeatableClause<omp::clause::Depend>(
+ [&](const omp::clause::Depend &clause,
const Fortran::parser::CharBlock &) {
- const std::list<Fortran::parser::Designator> &depVal =
- std::get<std::list<Fortran::parser::Designator>>(
- std::get<Fortran::parser::OmpDependClause::InOut>(
- dependClause->v.u)
- .t);
+ assert(std::holds_alternative<omp::clause::Depend::InOut>(clause.u) &&
+ "Only InOut is handled at the moment");
+ const auto &inOut = std::get<omp::clause::Depend::InOut>(clause.u);
+ const auto &objects = std::get<omp::ObjectList>(inOut.t);
+
mlir::omp::ClauseTaskDependAttr dependTypeOperand =
- genDependKindAttr(firOpBuilder, dependClause);
- dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
- dependTypeOperand);
- for (const Fortran::parser::Designator &ompObject : depVal) {
- Fortran::semantics::Symbol *sym = nullptr;
- std::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::DataRef &designator) {
- if (const Fortran::parser::Name *name =
- std::get_if<Fortran::parser::Name>(&designator.u)) {
- sym = name->symbol;
- } else if (std::get_if<Fortran::common::Indirection<
- Fortran::parser::ArrayElement>>(
- &designator.u)) {
- TODO(converter.getCurrentLocation(),
- "array sections not supported for task depend");
- }
- },
- [&](const Fortran::parser::Substring &designator) {
- TODO(converter.getCurrentLocation(),
- "substring not supported for task depend");
- }},
- (ompObject).u);
+ genDependKindAttr(firOpBuilder, clause);
+ dependTypeOperands.append(objects.size(), dependTypeOperand);
+
+ for (const omp::Object &object : objects) {
+ assert(object.ref() && "Expecting designator");
+
+ if (Fortran::evaluate::ExtractSubstring(*object.ref())) {
+ TODO(converter.getCurrentLocation(),
+ "substring not supported for task depend");
+ } else if (Fortran::evaluate::IsArrayElement(*object.ref())) {
+ TODO(converter.getCurrentLocation(),
+ "array sections not supported for task depend");
+ }
+
+ Fortran::semantics::Symbol *sym = object.id();
const mlir::Value variable = converter.getSymbolAddress(*sym);
dependOperands.push_back(variable);
}
@@ -780,14 +757,14 @@ bool ClauseProcessor::processDepend(
}
bool ClauseProcessor::processIf(
- Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
+ omp::clause::If::DirectiveNameModifier directiveName,
mlir::Value &result) const {
bool found = false;
- findRepeatableClause<ClauseTy::If>(
- [&](const ClauseTy::If *ifClause,
+ findRepeatableClause<omp::clause::If>(
+ [&](const omp::clause::If &clause,
const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
- mlir::Value operand = getIfClauseOperand(converter, ifClause,
+ mlir::Value operand = getIfClauseOperand(converter, clause,
directiveName, clauseLocation);
// Assume that, at most, a single 'if' clause will be applicable to the
// given directive.
@@ -801,12 +778,11 @@ bool ClauseProcessor::processIf(
bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
- return findRepeatableClause<ClauseTy::Link>(
- [&](const ClauseTy::Link *linkClause,
- const Fortran::parser::CharBlock &) {
+ return findRepeatableClause<omp::clause::Link>(
+ [&](const omp::clause::Link &clause, const Fortran::parser::CharBlock &) {
// Case: declare target link(var1, var2)...
gatherFuncAndVarSyms(
- linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result);
+ clause.v, mlir::omp::DeclareTargetCaptureClause::link, result);
});
}
@@ -843,7 +819,7 @@ bool ClauseProcessor::processMap(
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- return findRepeatableClause<ClauseTy::Map>(
+ return findRepeatableClause2<ClauseTy::Map>(
[&](const ClauseTy::Map *mapClause,
const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
@@ -935,43 +911,41 @@ bool ClauseProcessor::processReduction(
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
const {
- return findRepeatableClause<ClauseTy::Reduction>(
- [&](const ClauseTy::Reduction *reductionClause,
+ return findRepeatableClause<omp::clause::Reduction>(
+ [&](const omp::clause::Reduction &clause,
const Fortran::parser::CharBlock &) {
ReductionProcessor rp;
- rp.addReductionDecl(currentLocation, converter, reductionClause->v,
- reductionVars, reductionDeclSymbols,
- reductionSymbols);
+ rp.addReductionDecl(currentLocation, converter, clause, reductionVars,
+ reductionDeclSymbols, reductionSymbols);
});
}
bool ClauseProcessor::processSectionsReduction(
mlir::Location currentLocation) const {
- return findRepeatableClause<ClauseTy::Reduction>(
- [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) {
+ return findRepeatableClause<omp::clause::Reduction>(
+ [&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) {
TODO(currentLocation, "OMPC_Reduction");
});
}
bool ClauseProcessor::processTo(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
- return findRepeatableClause<ClauseTy::To>(
- [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) {
+ return findRepeatableClause<omp::clause::To>(
+ [&](const omp::clause::To &clause, const Fortran::parser::CharBlock &) {
// Case: declare target to(func, var1, var2)...
- gatherFuncAndVarSyms(toClause->v,
+ gatherFuncAndVarSyms(clause.v,
mlir::omp::DeclareTargetCaptureClause::to, result);
});
}
bool ClauseProcessor::processEnter(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
- return findRepeatableClause<ClauseTy::Enter>(
- [&](const ClauseTy::Enter *enterClause,
+ return findRepeatableClause<omp::clause::Enter>(
+ [&](const omp::clause::Enter &clause,
const Fortran::parser::CharBlock &) {
// Case: declare target enter(func, var1, var2)...
- gatherFuncAndVarSyms(enterClause->v,
- mlir::omp::DeclareTargetCaptureClause::enter,
- result);
+ gatherFuncAndVarSyms(
+ clause.v, mlir::omp::DeclareTargetCaptureClause::enter, result);
});
}
@@ -981,11 +955,11 @@ bool ClauseProcessor::processUseDeviceAddr(
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
const {
- return findRepeatableClause<ClauseTy::UseDeviceAddr>(
- [&](const ClauseTy::UseDeviceAddr *devAddrClause,
+ return findRepeatableClause<omp::clause::UseDeviceAddr>(
+ [&](const omp::clause::UseDeviceAddr &clause,
const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, devAddrClause->v, operands,
- useDeviceTypes, useDeviceLocs, useDeviceSymbols);
+ addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
+ useDeviceLocs, useDeviceSymbols);
});
}
@@ -995,10 +969,10 @@ bool ClauseProcessor::processUseDevicePtr(
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
const {
- return findRepeatableClause<ClauseTy::UseDevicePtr>(
- [&](const ClauseTy::UseDevicePtr *devPtrClause,
+ return findRepeatableClause<omp::clause::UseDevicePtr>(
+ [&](const omp::clause::UseDevicePtr &clause,
const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes,
+ addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
useDeviceLocs, useDeviceSymbols);
});
}
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index c87fc30c88bb93..3f6adcce8ae877 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -105,9 +105,8 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
- bool
- processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
- mlir::Value &result) const;
+ bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
+ mlir::Value &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
@@ -178,6 +177,10 @@ class ClauseProcessor {
/// if at least one instance was found.
template <typename T>
bool findRepeatableClause(
+ std::function<void(const T &, const Fortran::parser::CharBlock &source)>
+ callbackFn) const;
+ template <typename T>
+ bool findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const;
@@ -195,7 +198,7 @@ template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
- return findRepeatableClause<T>(
+ return findRepeatableClause2<T>(
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -295,6 +298,24 @@ const T *ClauseProcessor::findUniqueClause(
template <typename T>
bool ClauseProcessor::findRepeatableClause(
+ std::function<void(const T &, const Fortran::parser::CharBlock &source)>
+ callbackFn) const {
+ bool found = false;
+ ClauseIterator nextIt, endIt = clauses.end();
+ for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
+ nextIt = findClause<T>(it, endIt);
+
+ if (nextIt != endIt) {
+ callbackFn(std::get<T>(nextIt->u), nextIt->source);
+ found = true;
+ ++nextIt;
+ }
+ }
+ return found;
+}
+
+template <typename T>
+bool ClauseProcessor::findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index 0b90b705b9e406..a3aa3d4de3cdc9 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -205,12 +205,6 @@ namespace clause {
#undef EMPTY_CLASS
#undef WRAPPER_CLASS
-using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
-using ProcedureDesignator =
- tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
-using ReductionOperator =
- tomp::clause::ReductionOperatorT<SymIdent, SymReference>;
-
DefinedOperator makeDefOp(const parser::DefinedOperator &inp,
semantics::SemanticsContext &semaCtx) {
return DefinedOperator{
diff --git a/flang/lib/Lower/OpenMP/Clauses.h b/flang/lib/Lower/OpenMP/Clauses.h
index a7e563f4b0f90b..c167e34637d500 100644
--- a/flang/lib/Lower/OpenMP/Clauses.h
+++ b/flang/lib/Lower/OpenMP/Clauses.h
@@ -106,6 +106,12 @@ getBaseObject(const Object &object,
Fortran::semantics::SemanticsContext &semaCtx);
namespace clause {
+using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
+using ProcedureDesignator =
+ tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
+using ReductionOperator =
+ tomp::clause::ReductionOperatorT<SymIdent, SymReference>;
+
#ifdef EMPTY_CLASS
#undef EMPTY_CLASS
#endif
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7953bf83cba0fe..7445c0f13526f7 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -572,8 +572,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
- ifClauseOperand);
+ cp.processIf(clause::If::DirectiveNameModifier::Parallel, ifClauseOperand);
cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
cp.processProcBind(procBindKindAttr);
cp.processDefault();
@@ -676,8 +675,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
dependOperands;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
- ifClauseOperand);
+ cp.processIf(clause::If::DirectiveNameModifier::Task, ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processDefault();
cp.processFinal(stmtCtx, finalClauseOperand);
@@ -738,7 +736,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
+ cp.processIf(clause::If::DirectiveNameModifier::TargetData,
ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
@@ -770,19 +768,16 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
- Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
+ clause::If::DirectiveNameModifier directiveName;
llvm::omp::Directive directive;
if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
- directiveName =
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
+ directiveName = clause::If::DirectiveNameModifier::TargetEnterData;
directive = llvm::omp::Directive::OMPD_target_enter_data;
} else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
- directiveName =
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
+ directiveName = clause::If::DirectiveNameModifier::TargetExitData;
directive = llvm::omp::Directive::OMPD_target_exit_data;
} else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
- directiveName =
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
+ directiveName = clause::If::DirectiveNameModifier::TargetUpdate;
directive = llvm::omp::Directive::OMPD_target_update;
} else {
return nullptr;
@@ -984,8 +979,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
- ifClauseOperand);
+ cp.processIf(clause::If::DirectiveNameModifier::Target, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processThreadLimit(stmtCtx, threadLimitOperand);
cp.processDepend(dependTypeOperands, dependOperands);
@@ -1102,8 +1096,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
- ifClauseOperand);
+ cp.processIf(clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processDefault();
cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
@@ -1142,8 +1135,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
+ ObjectList objects{makeList(*objectList, semaCtx)};
// Case: declare target(func, var1, var2)
- gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
+ gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
symbolAndClause);
} else if (const auto *clauseList{
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
@@ -1257,7 +1251,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
if (const auto &ompObjectList =
std::get<std::optional<Fortran::parser::OmpObjectList>>(
flushConstruct.t))
- genObjectList(*ompObjectList, converter, operandRange);
+ genObjectList2(*ompObjectList, converter, operandRange);
const auto &memOrderClause =
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
flushConstruct.t);
@@ -1419,8 +1413,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
loopVarTypeSize);
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
- ifClauseOperand);
+ cp.processIf(clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
cp.processSimdlen(simdlenClauseOperand);
cp.processSafelen(safelenClauseOperand);
cp.processTODO<Fortran::parser::OmpClause::Aligned,
@@ -2223,106 +2216,99 @@ void Fortran::lower::genOpenMPReduction(
const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- for (const Fortran::parser::OmpClause &clause : clauseList.v) {
+ List<Clause> clauses{makeList(clauseList, semaCtx)};
+
+ for (const Clause &clause : clauses) {
if (const auto &reductionClause =
- std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
- const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
- reductionClause->v.t)};
- const auto &objectList{
- std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
+ std::get_if<clause::Reduction>(&clause.u)) {
+ const auto &redOperator{
+ std::get<clause::ReductionOperator>(reductionClause->t)};
+ const auto &objects{std::get<ObjectList>(reductionClause->t)};
if (const auto *reductionOp =
- std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+ std::get_if<clause::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
- std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+ std::get<clause::DefinedOperator::IntrinsicOperator>(
reductionOp->u)};
switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ case clause::DefinedOperator::IntrinsicOperator::Add:
+ case clause::DefinedOperator::IntrinsicOperator::Multiply:
+ case clause::DefinedOperator::IntrinsicOperator::AND:
+ case clause::DefinedOperator::IntrinsicOperator::EQV:
+ case clause::DefinedOperator::IntrinsicOperator::OR:
+ case clause::DefinedOperator::IntrinsicOperator::NEQV:
break;
default:
continue;
}
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- mlir::Type reductionType =
- reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
- if (!reductionType.isa<fir::LogicalType>()) {
- if (!reductionType.isIntOrIndexOrFloat())
- continue;
- }
- for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
- reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- if (reductionType.isa<fir::LogicalType>()) {
- mlir::Operation *reductionOp = findReductionChain(loadVal);
- fir::ConvertOp convertOp =
- getConvertFromReductionOp(reductionOp, loadVal);
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal, &convertOp);
- removeStoreOp(reductionOp, reductionVal);
- } else if (mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal)) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
+ for (const Object &object : objects) {
+ if (const Fortran::semantics::Symbol *symbol = object.id()) {
+ mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+ reductionVal = declOp.getBase();
+ mlir::Type reductionType =
+ reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
+ if (!reductionType.isa<fir::LogicalType>()) {
+ if (!reductionType.isIntOrIndexOrFloat())
+ continue;
+ }
+ for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
+ if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+ mlir::Value loadVal = loadOp.getRes();
+ if (reductionType.isa<fir::LogicalType>()) {
+ mlir::Operation *reductionOp = findReductionChain(loadVal);
+ fir::ConvertOp convertOp =
+ getConvertFromReductionOp(reductionOp, loadVal);
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal, &convertOp);
+ removeStoreOp(reductionOp, reductionVal);
+ } else if (mlir::Operation *reductionOp =
+ findReductionChain(loadVal, &reductionVal)) {
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal);
}
}
}
}
}
} else if (const auto *reductionIntrinsic =
- std::get_if<Fortran::parser::ProcedureDesignator>(
+ std::get_if<clause::ProcedureDesignator>(
&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic))
continue;
ReductionProcessor::ReductionIdentifier redId =
ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- for (const mlir::OpOperand &reductionValUse :
- reductionVal.getUses()) {
- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
- reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- // Max is lowered as a compare -> select.
- // Match the pattern here.
- mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal);
- if (reductionOp == nullptr)
- continue;
-
- if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
- redId == ReductionProcessor::ReductionIdentifier::MIN) {
- assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
- "Selection Op not found in reduction intrinsic");
- mlir::Operation *compareOp =
- getCompareFromReductionOp(reductionOp, loadVal);
- updateReduction(compareOp, firOpBuilder, loadVal,
- reductionVal);
- }
- if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
- redId == ReductionProcessor::ReductionIdentifier::IEOR ||
- redId == ReductionProcessor::ReductionIdentifier::IAND) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
+ for (const Object &object : objects) {
+ if (const Fortran::semantics::Symbol *symbol = object.id()) {
+ mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+ reductionVal = declOp.getBase();
+ for (const mlir::OpOperand &reductionValUse :
+ reductionVal.getUses()) {
+ if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+ mlir::Value loadVal = loadOp.getRes();
+ // Max is lowered as a compare -> select.
+ // Match the pattern here.
+ mlir::Operation *reductionOp =
+ findReductionChain(loadVal, &reductionVal);
+ if (reductionOp == nullptr)
+ continue;
+
+ if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
+ redId == ReductionProcessor::ReductionIdentifier::MIN) {
+ assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+ "Selection Op not found in reduction intrinsic");
+ mlir::Operation *compareOp =
+ getCompareFromReductionOp(reductionOp, loadVal);
+ updateReduction(compareOp, firOpBuilder, loadVal,
+ reductionVal);
+ }
+ if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
+ redId == ReductionProcessor::ReductionIdentifier::IEOR ||
+ redId == ReductionProcessor::ReductionIdentifier::IAND) {
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal);
}
}
}
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index a8b98f3f567249..bf755b27487d95 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -23,9 +23,9 @@ namespace lower {
namespace omp {
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
- const Fortran::parser::ProcedureDesignator &pd) {
+ const omp::clause::ProcedureDesignator &pd) {
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
- ReductionProcessor::getRealName(pd).ToString())
+ getRealName(pd.v.id()).ToString())
.Case("max", ReductionIdentifier::MAX)
.Case("min", ReductionIdentifier::MIN)
.Case("iand", ReductionIdentifier::IAND)
@@ -37,21 +37,21 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
}
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+ omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Add:
return ReductionIdentifier::ADD;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
return ReductionIdentifier::SUBTRACT;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
return ReductionIdentifier::MULTIPLY;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case omp::clause::DefinedOperator::IntrinsicOperator::AND:
return ReductionIdentifier::AND;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
return ReductionIdentifier::EQV;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case omp::clause::DefinedOperator::IntrinsicOperator::OR:
return ReductionIdentifier::OR;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
return ReductionIdentifier::NEQV;
default:
llvm_unreachable("unexpected intrinsic operator in reduction");
@@ -59,13 +59,11 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
}
bool ReductionProcessor::supportedIntrinsicProcReduction(
- const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- if (!name->symbol->GetUltimate().attrs().test(
- Fortran::semantics::Attr::INTRINSIC))
+ const omp::clause::ProcedureDesignator &pd) {
+ Fortran::semantics::Symbol *sym = pd.v.id();
+ if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
return false;
- auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString())
+ auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
.Case("max", true)
.Case("min", true)
.Case("iand", true)
@@ -84,24 +82,24 @@ std::string ReductionProcessor::getReductionName(llvm::StringRef name,
}
std::string ReductionProcessor::getReductionName(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
mlir::Type ty) {
std::string reductionName;
switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Add:
reductionName = "add_reduction";
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
reductionName = "multiply_reduction";
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case omp::clause::DefinedOperator::IntrinsicOperator::AND:
return "and_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
return "eqv_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case omp::clause::DefinedOperator::IntrinsicOperator::OR:
return "or_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
return "neqv_reduction";
default:
reductionName = "other_reduction";
@@ -305,7 +303,7 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
void ReductionProcessor::addReductionDecl(
mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpReductionClause &reduction,
+ const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -313,12 +311,12 @@ void ReductionProcessor::addReductionDecl(
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::omp::ReductionDeclareOp decl;
const auto &redOperator{
- std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
- const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+ std::get<omp::clause::ReductionOperator>(reduction.t)};
+ const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
if (const auto &redDefinedOp =
- std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+ std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
- std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+ std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
redDefinedOp->u)};
ReductionIdentifier redId = getReductionType(intrinsicOp);
switch (redId) {
@@ -334,10 +332,41 @@ void ReductionProcessor::addReductionDecl(
"Reduction of some intrinsic operators is not supported");
break;
}
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ for (const omp::Object &object : objectList) {
+ if (const Fortran::semantics::Symbol *symbol = object.id()) {
+ if (reductionSymbols)
+ reductionSymbols->push_back(symbol);
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+ symVal = declOp.getBase();
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ if (redType.isa<fir::LogicalType>())
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
+ redType, currentLocation);
+ else if (redType.isIntOrIndexOrFloat()) {
+ decl = createReductionDecl(firOpBuilder,
+ getReductionName(intrinsicOp, redType),
+ redId, redType, currentLocation);
+ } else {
+ TODO(currentLocation, "Reduction of some types is not supported");
+ }
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ } else if (const auto *reductionIntrinsic =
+ std::get_if<omp::clause::ProcedureDesignator>(
+ &redOperator.u)) {
+ if (ReductionProcessor::supportedIntrinsicProcReduction(
+ *reductionIntrinsic)) {
+ ReductionProcessor::ReductionIdentifier redId =
+ ReductionProcessor::getReductionType(*reductionIntrinsic);
+ for (const omp::Object &object : objectList) {
+ if (const Fortran::semantics::Symbol *symbol = object.id()) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
@@ -346,68 +375,28 @@ void ReductionProcessor::addReductionDecl(
mlir::Type redType =
symVal.getType().cast<fir::ReferenceType>().getEleTy();
reductionVars.push_back(symVal);
- if (redType.isa<fir::LogicalType>())
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
- redType, currentLocation);
- else if (redType.isIntOrIndexOrFloat()) {
- decl = createReductionDecl(firOpBuilder,
- getReductionName(intrinsicOp, redType),
- redId, redType, currentLocation);
- } else {
- TODO(currentLocation, "Reduction of some types is not supported");
- }
+ assert(redType.isIntOrIndexOrFloat() && "Unsupported reduction type");
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(getRealName(*reductionIntrinsic).ToString(),
+ redType),
+ redId, redType, currentLocation);
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
firOpBuilder.getContext(), decl.getSymName()));
}
}
}
- } else if (const auto *reductionIntrinsic =
- std::get_if<Fortran::parser::ProcedureDesignator>(
- &redOperator.u)) {
- if (ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic)) {
- ReductionProcessor::ReductionIdentifier redId =
- ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- if (reductionSymbols)
- reductionSymbols->push_back(symbol);
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- assert(redType.isIntOrIndexOrFloat() &&
- "Unsupported reduction type");
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(getRealName(*reductionIntrinsic).ToString(),
- redType),
- redId, redType, currentLocation);
- reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- }
}
}
const Fortran::semantics::SourceName
-ReductionProcessor::getRealName(const Fortran::parser::Name *name) {
- return name->symbol->GetUltimate().name();
+ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) {
+ return symbol->GetUltimate().name();
}
-const Fortran::semantics::SourceName ReductionProcessor::getRealName(
- const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- return getRealName(name);
+const Fortran::semantics::SourceName
+ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
+ return getRealName(pd.v.id());
}
int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 00770fe81d1ef6..855e2aa4ad13cd 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -13,6 +13,7 @@
#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
+#include "Clauses.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
@@ -57,25 +58,25 @@ class ReductionProcessor {
};
static ReductionIdentifier
- getReductionType(const Fortran::parser::ProcedureDesignator &pd);
+ getReductionType(const omp::clause::ProcedureDesignator &pd);
- static ReductionIdentifier getReductionType(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp);
+ static ReductionIdentifier
+ getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp);
- static bool supportedIntrinsicProcReduction(
- const Fortran::parser::ProcedureDesignator &pd);
+ static bool
+ supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd);
static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::Name *name);
+ getRealName(const Fortran::semantics::Symbol *symbol);
static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::ProcedureDesignator &pd);
+ getRealName(const omp::clause::ProcedureDesignator &pd);
static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
- static std::string getReductionName(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty);
+ static std::string
+ getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Type ty);
/// This function returns the identity value of the operator \p
/// reductionOpName. For example:
@@ -112,7 +113,7 @@ class ReductionProcessor {
static void
addReductionDecl(mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpReductionClause &reduction,
+ const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 31b15257d18687..9a6a28ded7006d 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "Utils.h"
+#include "Clauses.h"
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
@@ -28,9 +29,27 @@ namespace Fortran {
namespace lower {
namespace omp {
-void genObjectList(const Fortran::parser::OmpObjectList &objectList,
+void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands) {
+ for (const Object &object : objects) {
+ const Fortran::semantics::Symbol *sym = object.id();
+ assert(sym && "Expected Symbol");
+ if (mlir::Value variable = converter.getSymbolAddress(*sym)) {
+ operands.push_back(variable);
+ } else {
+ if (const auto *details =
+ sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
+ operands.push_back(converter.getSymbolAddress(details->symbol()));
+ converter.copySymbolBinding(details->symbol(), *sym);
+ }
+ }
+ }
+}
+
+void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
+ Fortran::lower::AbstractConverter &converter,
+ llvm::SmallVectorImpl<mlir::Value> &operands) {
auto addOperands = [&](Fortran::lower::SymbolRef sym) {
const mlir::Value variable = converter.getSymbolAddress(sym);
if (variable) {
@@ -50,24 +69,10 @@ void genObjectList(const Fortran::parser::OmpObjectList &objectList,
}
void gatherFuncAndVarSyms(
- const Fortran::parser::OmpObjectList &objList,
- mlir::omp::DeclareTargetCaptureClause clause,
+ const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
- for (const Fortran::parser::OmpObject &ompObject : objList.v) {
- Fortran::common::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::Designator &designator) {
- if (const Fortran::parser::Name *name =
- Fortran::semantics::getDesignatorNameIfDataRef(
- designator)) {
- symbolAndClause.emplace_back(clause, *name->symbol);
- }
- },
- [&](const Fortran::parser::Name &name) {
- symbolAndClause.emplace_back(clause, *name.symbol);
- }},
- ompObject.u);
- }
+ for (const Object &object : objects)
+ symbolAndClause.emplace_back(clause, *object.id());
}
Fortran::semantics::Symbol *
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index c346f891f0797e..4ab4bc9c137071 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -9,6 +9,7 @@
#ifndef FORTRAN_LOWER_OPENMPUTILS_H
#define FORTRAN_LOWER_OPENMPUTILS_H
+#include "Clauses.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
@@ -50,17 +51,20 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
bool isVal = false);
void gatherFuncAndVarSyms(
- const Fortran::parser::OmpObjectList &objList,
- mlir::omp::DeclareTargetCaptureClause clause,
+ const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);
Fortran::semantics::Symbol *
getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject);
-void genObjectList(const Fortran::parser::OmpObjectList &objectList,
+void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands);
+void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
+ Fortran::lower::AbstractConverter &converter,
+ llvm::SmallVectorImpl<mlir::Value> &operands);
+
} // namespace omp
} // namespace lower
} // namespace Fortran
More information about the llvm-branch-commits
mailing list