[llvm-branch-commits] [flang] [flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProcessor (PR #81623)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Feb 13 08:32:37 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Krzysztof Parzyszek (kparzysz)
<details>
<summary>Changes</summary>
(continuation of the truncated PR title: …ClauseProcessor)
Rename `findRepeatableClause` to `findRepeatableClause2`, and make the new `findRepeatableClause` operate on new `omp::Clause` objects.
Leave `Map` unchanged for now, because converting it will require further changes.
---
Patch is 51.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81623.diff
2 Files Affected:
- (modified) flang/include/flang/Evaluate/tools.h (+23)
- (modified) flang/lib/Lower/OpenMP.cpp (+305-327)
``````````diff
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index d257da1a709642..e9999974944e88 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -430,6 +430,29 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
}
}
+struct ExtractSubstringHelper {
+ template <typename T> static std::optional<Substring> visit(T &&) {
+ return std::nullopt;
+ }
+
+ static std::optional<Substring> visit(const Substring &e) { return e; }
+
+ template <typename T>
+ static std::optional<Substring> visit(const Designator<T> &e) {
+ return std::visit([](auto &&s) { return visit(s); }, e.u);
+ }
+
+ template <typename T>
+ static std::optional<Substring> visit(const Expr<T> &e) {
+ return std::visit([](auto &&s) { return visit(s); }, e.u);
+ }
+};
+
+template <typename A>
+std::optional<Substring> ExtractSubstring(const A &x) {
+ return ExtractSubstringHelper::visit(x);
+}
+
// If an expression is simply a whole symbol data designator,
// extract and return that symbol, else null.
template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index d7a93db15a4bb8..4b21ab934c9393 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -72,9 +72,9 @@ getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
return sym;
}
-static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- llvm::SmallVectorImpl<mlir::Value> &operands) {
+static void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
+ Fortran::lower::AbstractConverter &converter,
+ llvm::SmallVectorImpl<mlir::Value> &operands) {
auto addOperands = [&](Fortran::lower::SymbolRef sym) {
const mlir::Value variable = converter.getSymbolAddress(sym);
if (variable) {
@@ -93,27 +93,6 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
}
}
-static void gatherFuncAndVarSyms(
- const Fortran::parser::OmpObjectList &objList,
- mlir::omp::DeclareTargetCaptureClause clause,
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
- for (const Fortran::parser::OmpObject &ompObject : objList.v) {
- Fortran::common::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::Designator &designator) {
- if (const Fortran::parser::Name *name =
- Fortran::semantics::getDesignatorNameIfDataRef(
- designator)) {
- symbolAndClause.emplace_back(clause, *name->symbol);
- }
- },
- [&](const Fortran::parser::Name &name) {
- symbolAndClause.emplace_back(clause, *name.symbol);
- }},
- ompObject.u);
- }
-}
-
static Fortran::lower::pft::Evaluation *
getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) {
// Return the Evaluation of the innermost collapsed loop, or the current one
@@ -1257,6 +1236,32 @@ List<Clause> makeList(const parser::OmpClauseList &clauses,
}
} // namespace omp
+static void genObjectList(const omp::ObjectList &objects,
+ Fortran::lower::AbstractConverter &converter,
+ llvm::SmallVectorImpl<mlir::Value> &operands) {
+ for (const omp::Object &object : objects) {
+ const Fortran::semantics::Symbol *sym = object.sym;
+ assert(sym && "Expected Symbol");
+ if (mlir::Value variable = converter.getSymbolAddress(*sym)) {
+ operands.push_back(variable);
+ } else {
+ if (const auto *details =
+ sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
+ operands.push_back(converter.getSymbolAddress(details->symbol()));
+ converter.copySymbolBinding(details->symbol(), *sym);
+ }
+ }
+ }
+}
+
+static void gatherFuncAndVarSyms(
+ const omp::ObjectList &objects,
+ mlir::omp::DeclareTargetCaptureClause clause,
+ llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
+ for (const omp::Object &object : objects)
+ symbolAndClause.emplace_back(clause, *object.sym);
+}
+
//===----------------------------------------------------------------------===//
// DataSharingProcessor
//===----------------------------------------------------------------------===//
@@ -1718,9 +1723,8 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
- bool
- processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
- mlir::Value &result) const;
+ bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
+ mlir::Value &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
@@ -1815,6 +1819,26 @@ class ClauseProcessor {
/// if at least one instance was found.
template <typename T>
bool findRepeatableClause(
+ std::function<void(const T &, const Fortran::parser::CharBlock &source)>
+ callbackFn) const {
+ bool found = false;
+ ClauseIterator nextIt, endIt = clauses.end();
+ for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
+ nextIt = findClause<T>(it, endIt);
+
+ if (nextIt != endIt) {
+ callbackFn(std::get<T>(nextIt->u), nextIt->source);
+ found = true;
+ ++nextIt;
+ }
+ }
+ return found;
+ }
+
+ /// Call `callbackFn` for each occurrence of the given clause. Return `true`
+ /// if at least one instance was found.
+ template <typename T>
+ bool findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
@@ -1880,9 +1904,9 @@ class ReductionProcessor {
IEOR
};
static ReductionIdentifier
- getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
+ getReductionType(const omp::clause::ProcedureDesignator &pd) {
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
- getRealName(pd).ToString())
+ getRealName(pd.v.sym).ToString())
.Case("max", ReductionIdentifier::MAX)
.Case("min", ReductionIdentifier::MIN)
.Case("iand", ReductionIdentifier::IAND)
@@ -1894,35 +1918,33 @@ class ReductionProcessor {
}
static ReductionIdentifier getReductionType(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+ omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Add:
return ReductionIdentifier::ADD;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
return ReductionIdentifier::SUBTRACT;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
return ReductionIdentifier::MULTIPLY;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case omp::clause::DefinedOperator::IntrinsicOperator::AND:
return ReductionIdentifier::AND;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
return ReductionIdentifier::EQV;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case omp::clause::DefinedOperator::IntrinsicOperator::OR:
return ReductionIdentifier::OR;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
return ReductionIdentifier::NEQV;
default:
llvm_unreachable("unexpected intrinsic operator in reduction");
}
}
- static bool supportedIntrinsicProcReduction(
- const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- if (!name->symbol->GetUltimate().attrs().test(
- Fortran::semantics::Attr::INTRINSIC))
+ static bool
+ supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd) {
+ Fortran::semantics::Symbol *sym = pd.v.sym;
+ if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
return false;
- auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString())
+ auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
.Case("max", true)
.Case("min", true)
.Case("iand", true)
@@ -1933,15 +1955,13 @@ class ReductionProcessor {
}
static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::Name *name) {
- return name->symbol->GetUltimate().name();
+ getRealName(const Fortran::semantics::Symbol *symbol) {
+ return symbol->GetUltimate().name();
}
static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- return getRealName(name);
+ getRealName(const omp::clause::ProcedureDesignator &pd) {
+ return getRealName(pd.v.sym);
}
static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
@@ -1951,25 +1971,25 @@ class ReductionProcessor {
.str();
}
- static std::string getReductionName(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty) {
+ static std::string
+ getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Type ty) {
std::string reductionName;
switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Add:
reductionName = "add_reduction";
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
reductionName = "multiply_reduction";
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case omp::clause::DefinedOperator::IntrinsicOperator::AND:
return "and_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
return "eqv_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case omp::clause::DefinedOperator::IntrinsicOperator::OR:
return "or_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
return "neqv_reduction";
default:
reductionName = "other_reduction";
@@ -2213,7 +2233,7 @@ class ReductionProcessor {
static void
addReductionDecl(mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpReductionClause &reduction,
+ const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -2221,13 +2241,12 @@ class ReductionProcessor {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::omp::ReductionDeclareOp decl;
const auto &redOperator{
- std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
- const auto &objectList{
- std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+ std::get<omp::clause::ReductionOperator>(reduction.t)};
+ const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
if (const auto &redDefinedOp =
- std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+ std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
- std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+ std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
redDefinedOp->u)};
ReductionIdentifier redId = getReductionType(intrinsicOp);
switch (redId) {
@@ -2243,10 +2262,41 @@ class ReductionProcessor {
"Reduction of some intrinsic operators is not supported");
break;
}
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ for (const omp::Object &object : objectList) {
+ if (const Fortran::semantics::Symbol *symbol = object.sym) {
+ if (reductionSymbols)
+ reductionSymbols->push_back(symbol);
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+ symVal = declOp.getBase();
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ if (redType.isa<fir::LogicalType>())
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
+ redType, currentLocation);
+ else if (redType.isIntOrIndexOrFloat()) {
+ decl = createReductionDecl(firOpBuilder,
+ getReductionName(intrinsicOp, redType),
+ redId, redType, currentLocation);
+ } else {
+ TODO(currentLocation, "Reduction of some types is not supported");
+ }
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ } else if (const auto *reductionIntrinsic =
+ std::get_if<omp::clause::ProcedureDesignator>(
+ &redOperator.u)) {
+ if (ReductionProcessor::supportedIntrinsicProcReduction(
+ *reductionIntrinsic)) {
+ ReductionProcessor::ReductionIdentifier redId =
+ ReductionProcessor::getReductionType(*reductionIntrinsic);
+ for (const omp::Object &object : objectList) {
+ if (const Fortran::semantics::Symbol *symbol = object.sym) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
@@ -2255,55 +2305,18 @@ class ReductionProcessor {
mlir::Type redType =
symVal.getType().cast<fir::ReferenceType>().getEleTy();
reductionVars.push_back(symVal);
- if (redType.isa<fir::LogicalType>())
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
- redId, redType, currentLocation);
- else if (redType.isIntOrIndexOrFloat()) {
- decl = createReductionDecl(firOpBuilder,
- getReductionName(intrinsicOp, redType),
- redId, redType, currentLocation);
- } else {
- TODO(currentLocation, "Reduction of some types is not supported");
- }
+ assert(redType.isIntOrIndexOrFloat() &&
+ "Unsupported reduction type");
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(getRealName(*reductionIntrinsic).ToString(),
+ redType),
+ redId, redType, currentLocation);
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
firOpBuilder.getContext(), decl.getSymName()));
}
}
}
- } else if (const auto *reductionIntrinsic =
- std::get_if<Fortran::parser::ProcedureDesignator>(
- &redOperator.u)) {
- if (ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic)) {
- ReductionProcessor::ReductionIdentifier redId =
- ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- if (reductionSymbols)
- reductionSymbols->push_back(symbol);
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- assert(redType.isIntOrIndexOrFloat() &&
- "Unsupported reduction type");
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(getRealName(*reductionIntrinsic).ToString(),
- redType),
- redId, redType, currentLocation);
- reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- }
}
}
};
@@ -2365,7 +2378,7 @@ getSimdModifier(const omp::clause::Schedule &clause) {
static void
genAllocateClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpAllocateClause &ompAllocateClause,
+ const omp::clause::Allocate &clause,
llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -2373,21 +2386,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value allocatorOperand;
- const Fortran::parser::OmpObjectList &ompObjectList =
- std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
- const auto &allocateModifier = std::get<
- std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>(
- ompAllocateClause.t);
+ const omp::ObjectList &objectList = std::get<omp::ObjectList>(clause.t);
+ const auto &modifier =
+ std::get<std::optional<omp::clause::Allocate::Modifier>>(clause.t);
// If the allocate modifier is present, check if we only use the allocator
// submodifier. ALIGN in this context is unimplemented
const bool onlyAllocator =
- allocateModifier &&
- std::holds_alternative<
- Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
- allocateModifier->u);
+ modifier &&
+ std::holds_alternative<omp::clause::Allocate::Modifier::Allocator>(
+ modifier->u);
- if (allocateModifier && !onlyAllocator) {
+ if (modifier && !onlyAllocator) {
TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
}
@@ -2395,20 +2405,17 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
// to list of allocators, otherwise, add default allocator to
// list of allocators.
if (onlyAllocator) {
- const auto &allocatorValue = std::get<
- Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
- allocateModifier->u);
- allocatorOperand = fir...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81623
More information about the llvm-branch-commits mailing list