[llvm-branch-commits] [flang] [flang][OpenMP] Convert DataSharingProcessor to omp::Clause (PR #81629)
Krzysztof Parzyszek via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Feb 19 12:04:42 PST 2024
https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/81629
>From 2f6d9000651e34d5a038671fcd061279f7a80354 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Mon, 5 Feb 2024 11:30:54 -0600
Subject: [PATCH 1/6] [flang][OpenMP] Implement flexible OpenMP clause
representation
The new set of classes representing OpenMP clauses mimics the
contents of parser::OmpClause, but differs in a few aspects:
- it can be easily created, copied, etc.
- it is based on semantics::SomeExpr instead of parser objects.
The parser class `OmpObject` is represented by `omp::Object`, which contains
the symbol associated with the object and a semantics::MaybeExpr
representing the designator for the symbol reference.
This patch only introduces the new classes; they are not yet used
anywhere.
---
flang/lib/Lower/OpenMP.cpp | 1116 ++++++++++++++++++++++++++++++++++++
1 file changed, 1116 insertions(+)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 9397af8b8bd05e..5ddc1eaa27e003 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -142,6 +142,1122 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
converter.genEval(e);
}
+//===----------------------------------------------------------------------===//
+// Clauses
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Return the llvm::omp::Clause id corresponding to the parser clause class
+/// C (one of the parser::OmpClause::* alternatives). The argument is unnamed:
+/// it is used only to deduce C.
+template <typename C> //
+llvm::omp::Clause getClauseIdForClass(C &&) {
+ using namespace Fortran;
+ using A = llvm::remove_cvref_t<C>; // A is referenced in OMP.inc
+ // The code included below contains a sequence of checks like the following
+ // for each OpenMP clause
+ // if constexpr (std::is_same_v<A, parser::OmpClause::AcqRel>)
+ // return llvm::omp::Clause::OMPC_acq_rel;
+ // [...]
+#define GEN_FLANG_CLAUSE_PARSER_KIND_MAP
+#include "llvm/Frontend/OpenMP/OMP.inc"
+ // NOTE(review): if none of the included checks matches A, control would
+ // fall off the end of this non-void function; confirm that the generated
+ // code terminates with llvm_unreachable (or equivalent).
+}
+} // namespace detail
+
+/// Map a parser-level OpenMP clause to its llvm::omp::Clause id, selected by
+/// the alternative currently held in the clause's variant.
+static llvm::omp::Clause getClauseId(const Fortran::parser::OmpClause &clause) {
+  auto toId = [](auto &&alternative) {
+    return detail::getClauseIdForClass(alternative);
+  };
+  return std::visit(toId, clause.u);
+}
+
+namespace omp {
+using namespace Fortran;
+
+// Short aliases for the semantic expression types used by the clause
+// representation below.
+using SomeType = evaluate::SomeType;
+using SomeExpr = semantics::SomeExpr;
+using MaybeExpr = semantics::MaybeExpr;
+
+/// Container used for every list-valued component of a clause.
+template <typename T> using List = std::vector<T>;
+
+/// Extracts the symbol and (when available) the designator expression from a
+/// semantic expression. The result tuple is {symbol, designator}. Both
+/// components are empty/null for expression alternatives that are neither a
+/// Designator nor a ProcedureDesignator.
+struct SymDsgExtractor {
+  using SymDsg = std::tuple<semantics::Symbol *, MaybeExpr>;
+
+  // An rvalue argument is forwarded as an rvalue; a const lvalue argument is
+  // copied. Either way the result can be consumed by evaluate::AsGenericExpr.
+  template <typename T> //
+  static T &&AsRvalueRef(T &&t) {
+    return std::move(t);
+  }
+  template <typename T> //
+  static T AsRvalueRef(const T &t) {
+    return t;
+  }
+
+  // Fallback for every expression alternative that does not name a symbol.
+  template <typename T> //
+  static SymDsg visit(T &&) {
+    // Use this to see missing overloads:
+    // llvm::errs() << "NULL: " << __PRETTY_FUNCTION__ << '\n';
+    return SymDsg{};
+  }
+
+  template <typename T> //
+  static SymDsg visit(const evaluate::Designator<T> &e) {
+    // Symbols cannot be created after semantic checks, so all symbol
+    // pointers that are non-null must point to one of those pre-existing
+    // objects. Throughout the code, symbols are often pointed to by
+    // non-const pointers, so there is no harm in casting the constness
+    // away.
+    return std::make_tuple(const_cast<semantics::Symbol *>(e.GetLastSymbol()),
+                           evaluate::AsGenericExpr(AsRvalueRef(e)));
+  }
+
+  static SymDsg visit(const evaluate::ProcedureDesignator &e) {
+    // See comment above regarding const_cast.
+    return std::make_tuple(const_cast<semantics::Symbol *>(e.GetSymbol()),
+                           std::nullopt);
+  }
+
+  template <typename T> //
+  static SymDsg visit(const evaluate::Expr<T> &e) {
+    return std::visit([](auto &&s) { return visit(s); }, e.u);
+  }
+
+  /// Sanity-check an extracted {symbol, designator} pair. Returns true; a
+  /// designator that is not a DataRef is treated as unreachable.
+  static bool verify(const SymDsg &sd) {
+    // Only read inside assert(), so it is otherwise unused under NDEBUG.
+    [[maybe_unused]] const semantics::Symbol *symbol = std::get<0>(sd);
+    assert(symbol && "Expecting Symbol");
+    auto &maybeDsg = std::get<1>(sd);
+    if (!maybeDsg)
+      return true;
+    std::optional<evaluate::DataRef> maybeRef =
+        evaluate::ExtractDataRef(*maybeDsg);
+    if (maybeRef) {
+      assert(&maybeRef->GetLastSymbol() == symbol &&
+             "Designator not for symbol");
+      return true;
+    }
+
+    // This could still be a Substring or ComplexPart, but at least Substring
+    // is not allowed in OpenMP.
+    maybeDsg->dump();
+    llvm_unreachable("Expecting DataRef");
+  }
+};
+
+/// Return the {symbol, designator} pair for `expr`, or an empty pair when no
+/// expression is present.
+SymDsgExtractor::SymDsg getSymbolAndDesignator(const MaybeExpr &expr) {
+  if (expr) {
+    auto extract = [](auto &&alt) { return SymDsgExtractor::visit(alt); };
+    return std::visit(extract, expr->u);
+  }
+  return SymDsgExtractor::SymDsg{};
+}
+
+/// An object named in a clause: the symbol it refers to, plus the designator
+/// expression ending with that symbol when one was produced.
+struct Object {
+  semantics::Symbol *sym; // referenced symbol
+  MaybeExpr dsg;          // designator ending with `sym`, if any
+};
+
+/// A list of clause objects.
+using ObjectList = List<Object>;
+
+// Check an extracted {symbol, designator} pair and turn it into an Object.
+static Object objectFromSymDsg(SymDsgExtractor::SymDsg &&sd) {
+  SymDsgExtractor::verify(sd);
+  auto &[sym, dsg] = sd;
+  return Object{sym, std::move(dsg)};
+}
+
+/// Convert a parser::OmpObject. A bare Name (used for common blocks) is taken
+/// directly from the parser, because the expression analyzer cannot handle it.
+Object makeObject(const parser::OmpObject &object,
+                  semantics::SemanticsContext &semaCtx) {
+  if (const auto *name = std::get_if<parser::Name>(&object.u)) {
+    assert(name->symbol && "Expecting Symbol");
+    return Object{name->symbol, std::nullopt};
+  }
+  evaluate::ExpressionAnalyzer ea{semaCtx};
+  return objectFromSymDsg(std::visit(
+      [&](auto &&s) { return getSymbolAndDesignator(ea.Analyze(s)); },
+      object.u));
+}
+
+/// Convert a resolved parser::Name; no designator is recorded.
+Object makeObject(const parser::Name &name,
+                  semantics::SemanticsContext &semaCtx) {
+  assert(name.symbol && "Expecting Symbol");
+  return Object{name.symbol, std::nullopt};
+}
+
+/// Convert a parser::Designator through expression analysis.
+Object makeObject(const parser::Designator &dsg,
+                  semantics::SemanticsContext &semaCtx) {
+  evaluate::ExpressionAnalyzer ea{semaCtx};
+  return objectFromSymDsg(getSymbolAndDesignator(ea.Analyze(dsg)));
+}
+
+/// Convert a parser::StructureComponent through expression analysis.
+Object makeObject(const parser::StructureComponent &comp,
+                  semantics::SemanticsContext &semaCtx) {
+  evaluate::ExpressionAnalyzer ea{semaCtx};
+  return objectFromSymDsg(getSymbolAndDesignator(ea.Analyze(comp)));
+}
+
+/// Adapter returning a callable `s -> makeObject(s, semaCtx)`, e.g. for
+/// makeList.
+auto makeObject(semantics::SemanticsContext &semaCtx) {
+  return [&](auto &&s) { return makeObject(s, semaCtx); };
+}
+
+/// Analyze a parser expression into a semantic expression. Analysis is
+/// asserted to succeed; note that with NDEBUG a failed analysis would
+/// dereference an empty optional.
+template <typename T>
+SomeExpr makeExpr(T &&inp, semantics::SemanticsContext &semaCtx) {
+  auto maybeExpr = evaluate::ExpressionAnalyzer(semaCtx).Analyze(inp);
+  assert(maybeExpr && "Expecting expression");
+  return std::move(*maybeExpr);
+}
+
+/// Adapter returning a callable `s -> makeExpr(s, semaCtx)`, e.g. for
+/// makeList or maybeApply.
+auto makeExpr(semantics::SemanticsContext &semaCtx) {
+  return [&](auto &&s) { return makeExpr(s, semaCtx); };
+}
+
+/// Apply `func` to each element of `container`, in order, collecting the
+/// results in a List.
+///
+/// The result type R is what `func` returns when invoked on an element as
+/// llvm::transform actually passes it: an lvalue reference, const when the
+/// container is const. (Computing it from an rvalue `value_type` could pick
+/// the wrong overload or fail for callables taking non-const references.)
+/// E is unused but kept so the template parameter list stays compatible.
+template <typename C, typename F,
+          typename E = typename llvm::remove_cvref_t<C>::value_type,
+          typename R = std::invoke_result_t<
+              F, decltype(*std::begin(std::declval<C &>()))>>
+List<R> makeList(C &&container, F &&func) {
+  List<R> v;
+  llvm::transform(container, std::back_inserter(v), func);
+  return v;
+}
+
+/// Convert every object in a parser object list.
+ObjectList makeList(const parser::OmpObjectList &objects,
+                    semantics::SemanticsContext &semaCtx) {
+  return makeList(objects.v, makeObject(semaCtx));
+}
+
+/// Convert between two enumeration types by numeric value. This relies on
+/// both enumerations assigning the same values to corresponding enumerators.
+template <typename U, typename T> //
+U enum_cast(T t) {
+  using BareT = llvm::remove_cvref_t<T>;
+  using BareU = llvm::remove_cvref_t<U>;
+  static_assert(std::is_enum_v<BareT> && std::is_enum_v<BareU>);
+
+  const auto raw = static_cast<std::underlying_type_t<BareT>>(t);
+  return U{raw};
+}
+
+/// Apply `func` to the value held by `inp`, if any, and wrap the result in a
+/// std::optional. Returns std::nullopt when `inp` is empty.
+///
+/// U is computed for the argument `func` actually receives (`const T &`).
+template <typename F, typename T,
+          typename U = std::invoke_result_t<F, const T &>>
+std::optional<U> maybeApply(F &&func, const std::optional<T> &inp) {
+  if (!inp)
+    return std::nullopt;
+  // The call already yields a prvalue; wrapping it in std::move would be
+  // redundant (clang -Wpessimizing-move/-Wredundant-move).
+  return func(*inp);
+}
+
+namespace clause {
+// For every argument-less clause listed in OMP.inc, generate an empty struct
+// and its trivial "make" converter. WRAPPER_CLASS entries expand to nothing
+// here; clauses with contents are defined explicitly below.
+#ifdef EMPTY_CLASS
+#undef EMPTY_CLASS
+#endif
+#define EMPTY_CLASS(cls)                                                       \
+  struct cls {                                                                 \
+    using EmptyTrait = std::true_type;                                         \
+  };                                                                           \
+  cls make(const parser::OmpClause::cls &, semantics::SemanticsContext &) {    \
+    return cls{};                                                              \
+  }
+
+#ifdef WRAPPER_CLASS
+#undef WRAPPER_CLASS
+#endif
+#define WRAPPER_CLASS(cls, content) // Nothing
+#define GEN_FLANG_CLAUSE_PARSER_CLASSES
+#include "llvm/Frontend/OpenMP/OMP.inc"
+// Do not leak either helper macro into the rest of the translation unit; an
+// empty WRAPPER_CLASS would otherwise silently swallow any later use.
+#undef EMPTY_CLASS
+#undef WRAPPER_CLASS
+
+// Helper objects
+
+/// A defined operator: either a named (user-defined) operator, or one of the
+/// intrinsic operators.
+struct DefinedOperator {
+  struct DefinedOpName {
+    using WrapperTrait = std::true_type;
+    Object v;
+  };
+  ENUM_CLASS(IntrinsicOperator, Power, Multiply, Divide, Add, Subtract, Concat,
+             LT, LE, EQ, NE, GE, GT, NOT, AND, OR, EQV, NEQV)
+  using UnionTrait = std::true_type;
+  std::variant<DefinedOpName, IntrinsicOperator> u;
+};
+
+DefinedOperator makeDefOp(const parser::DefinedOperator &inp,
+                          semantics::SemanticsContext &semaCtx) {
+  auto fromOpName = [&](const parser::DefinedOpName &name) {
+    return DefinedOperator{
+        DefinedOperator::DefinedOpName{makeObject(name.v, semaCtx)}};
+  };
+  auto fromIntrinsic =
+      [](const parser::DefinedOperator::IntrinsicOperator &op) {
+        return DefinedOperator{
+            enum_cast<DefinedOperator::IntrinsicOperator>(op)};
+      };
+  return std::visit(common::visitors{fromOpName, fromIntrinsic}, inp.u);
+}
+
+/// A procedure designator, represented by the object it names.
+struct ProcedureDesignator {
+  using WrapperTrait = std::true_type;
+  Object v;
+};
+
+ProcedureDesignator makeProcDsg(const parser::ProcedureDesignator &inp,
+                                semantics::SemanticsContext &semaCtx) {
+  Object named = std::visit(
+      common::visitors{
+          [&](const parser::Name &name) { return makeObject(name, semaCtx); },
+          [&](const parser::ProcComponentRef &ref) {
+            return makeObject(ref.v.thing, semaCtx);
+          },
+      },
+      inp.u);
+  return ProcedureDesignator{std::move(named)};
+}
+
+/// A reduction operator: a defined/intrinsic operator or a procedure.
+struct ReductionOperator {
+  using UnionTrait = std::true_type;
+  std::variant<DefinedOperator, ProcedureDesignator> u;
+};
+
+ReductionOperator makeRedOp(const parser::OmpReductionOperator &inp,
+                            semantics::SemanticsContext &semaCtx) {
+  auto viaOperator = [&](const parser::DefinedOperator &op) {
+    return ReductionOperator{makeDefOp(op, semaCtx)};
+  };
+  auto viaProcedure = [&](const parser::ProcedureDesignator &proc) {
+    return ReductionOperator{makeProcDsg(proc, semaCtx)};
+  };
+  return std::visit(common::visitors{viaOperator, viaProcedure}, inp.u);
+}
+
+// Actual clauses. Each T (where OmpClause::T exists) has its "make".
+
+/// ALIGNED clause: a list of objects and an optional constant expression.
+struct Aligned {
+  using TupleTrait = std::true_type;
+  std::tuple<ObjectList, MaybeExpr> t;
+};
+
+Aligned make(const parser::OmpClause::Aligned &inp,
+             semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpAlignedClause
+  const auto &objects = std::get<parser::OmpObjectList>(inp.v.t);
+  const auto &alignment =
+      std::get<std::optional<parser::ScalarIntConstantExpr>>(inp.v.t);
+
+  ObjectList objectList = makeList(objects, semaCtx);
+  MaybeExpr alignExpr = maybeApply(makeExpr(semaCtx), alignment);
+  return Aligned{{std::move(objectList), std::move(alignExpr)}};
+}
+
+/// ALLOCATE clause: an optional allocator/alignment modifier and a list of
+/// objects.
+struct Allocate {
+  struct Modifier {
+    struct Allocator {
+      using WrapperTrait = std::true_type;
+      SomeExpr v;
+    };
+    struct Align {
+      using WrapperTrait = std::true_type;
+      SomeExpr v;
+    };
+    struct ComplexModifier {
+      using TupleTrait = std::true_type;
+      std::tuple<Allocator, Align> t;
+    };
+    using UnionTrait = std::true_type;
+    std::variant<Allocator, ComplexModifier, Align> u;
+  };
+  using TupleTrait = std::true_type;
+  std::tuple<std::optional<Modifier>, ObjectList> t;
+};
+
+Allocate make(const parser::OmpClause::Allocate &inp,
+              semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpAllocateClause
+  using wrapped = parser::OmpAllocateClause;
+  using ParserModifier = wrapped::AllocateModifier;
+  using Modifier = Allocate::Modifier;
+
+  const auto &maybeModifier =
+      std::get<std::optional<ParserModifier>>(inp.v.t);
+  const auto &objects = std::get<parser::OmpObjectList>(inp.v.t);
+
+  auto convert = [&](const ParserModifier &m) -> Modifier {
+    return std::visit(
+        common::visitors{
+            [&](const ParserModifier::Allocator &a) {
+              return Modifier{Modifier::Allocator{makeExpr(a.v, semaCtx)}};
+            },
+            [&](const ParserModifier::ComplexModifier &c) {
+              const auto &alloc = std::get<ParserModifier::Allocator>(c.t);
+              const auto &align = std::get<ParserModifier::Align>(c.t);
+              return Modifier{Modifier::ComplexModifier{{
+                  Modifier::Allocator{makeExpr(alloc.v, semaCtx)},
+                  Modifier::Align{makeExpr(align.v, semaCtx)},
+              }}};
+            },
+            [&](const ParserModifier::Align &a) {
+              return Modifier{Modifier::Align{makeExpr(a.v, semaCtx)}};
+            },
+        },
+        m.u);
+  };
+
+  return Allocate{
+      {maybeApply(convert, maybeModifier), makeList(objects, semaCtx)}};
+}
+
+/// ALLOCATOR clause: a single expression.
+struct Allocator {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Allocator make(const parser::OmpClause::Allocator &inp,
+               semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntExpr
+  SomeExpr expr = makeExpr(inp.v, semaCtx);
+  return Allocator{std::move(expr)};
+}
+
+/// ATOMIC_DEFAULT_MEM_ORDER clause: the memory-order kind.
+struct AtomicDefaultMemOrder {
+  using WrapperTrait = std::true_type;
+  common::OmpAtomicDefaultMemOrderType v;
+};
+
+AtomicDefaultMemOrder make(const parser::OmpClause::AtomicDefaultMemOrder &inp,
+                           semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpAtomicDefaultMemOrderClause
+  const common::OmpAtomicDefaultMemOrderType order = inp.v.v;
+  return AtomicDefaultMemOrder{order};
+}
+
+/// COLLAPSE clause: a single constant expression.
+struct Collapse {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Collapse make(const parser::OmpClause::Collapse &inp,
+              semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntConstantExpr
+  SomeExpr depth = makeExpr(inp.v, semaCtx);
+  return Collapse{std::move(depth)};
+}
+
+/// COPYIN clause: a list of objects.
+struct Copyin {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+Copyin make(const parser::OmpClause::Copyin &inp,
+            semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return Copyin{std::move(objects)};
+}
+
+/// COPYPRIVATE clause: a list of objects.
+struct Copyprivate {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+Copyprivate make(const parser::OmpClause::Copyprivate &inp,
+                 semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return Copyprivate{std::move(objects)};
+}
+
+/// DEFAULTMAP clause: an implicit behavior and an optional variable category.
+struct Defaultmap {
+  ENUM_CLASS(ImplicitBehavior, Alloc, To, From, Tofrom, Firstprivate, None,
+             Default)
+  ENUM_CLASS(VariableCategory, Scalar, Aggregate, Allocatable, Pointer)
+  using TupleTrait = std::true_type;
+  std::tuple<ImplicitBehavior, std::optional<VariableCategory>> t;
+};
+
+Defaultmap make(const parser::OmpClause::Defaultmap &inp,
+                semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpDefaultmapClause
+  using wrapped = parser::OmpDefaultmapClause;
+
+  const auto &behavior = std::get<wrapped::ImplicitBehavior>(inp.v.t);
+  const auto &category =
+      std::get<std::optional<wrapped::VariableCategory>>(inp.v.t);
+  auto toCategory = [](auto &&c) -> Defaultmap::VariableCategory {
+    return enum_cast<Defaultmap::VariableCategory>(c);
+  };
+  return Defaultmap{{enum_cast<Defaultmap::ImplicitBehavior>(behavior),
+                     maybeApply(toCategory, category)}};
+}
+
+/// DEFAULT clause: the default data-sharing kind.
+struct Default {
+  ENUM_CLASS(Type, Private, Firstprivate, Shared, None)
+  using WrapperTrait = std::true_type;
+  Type v;
+};
+
+Default make(const parser::OmpClause::Default &inp,
+             semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpDefaultClause
+  const auto kind = enum_cast<Default::Type>(inp.v.v);
+  return Default{kind};
+}
+
+/// DEPEND clause: SOURCE, SINK (a list of sink vectors), or a dependence type
+/// with a list of objects.
+struct Depend {
+  struct Source {
+    using EmptyTrait = std::true_type;
+  };
+  struct Sink {
+    using Length = std::tuple<DefinedOperator, SomeExpr>;
+    using Vec = std::tuple<Object, std::optional<Length>>;
+    using WrapperTrait = std::true_type;
+    List<Vec> v;
+  };
+  ENUM_CLASS(Type, In, Out, Inout, Source, Sink)
+  struct InOut {
+    using TupleTrait = std::true_type;
+    std::tuple<Type, ObjectList> t;
+  };
+  using UnionTrait = std::true_type;
+  std::variant<Source, Sink, InOut> u;
+};
+
+Depend make(const parser::OmpClause::Depend &inp,
+            semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpDependClause
+  using wrapped = parser::OmpDependClause;
+
+  // (operator, expression) pair attached to a sink vector element.
+  auto convertLength = [&](const parser::OmpDependSinkVecLength &len) {
+    const auto &op = std::get<parser::DefinedOperator>(len.t);
+    const auto &amount = std::get<parser::ScalarIntConstantExpr>(len.t);
+    return Depend::Sink::Length{makeDefOp(op, semaCtx),
+                                makeExpr(amount, semaCtx)};
+  };
+  // One sink vector element: a name plus an optional length.
+  auto convertVec = [&](const parser::OmpDependSinkVec &vec) {
+    const auto &name = std::get<parser::Name>(vec.t);
+    const auto &length =
+        std::get<std::optional<parser::OmpDependSinkVecLength>>(vec.t);
+    return Depend::Sink::Vec{makeObject(name, semaCtx),
+                             maybeApply(convertLength, length)};
+  };
+  auto convertDesignator = [&](const parser::Designator &dsg) {
+    return makeObject(dsg, semaCtx);
+  };
+
+  return std::visit(
+      common::visitors{
+          [](const wrapped::Source &) { return Depend{Depend::Source{}}; },
+          [&](const wrapped::Sink &sink) {
+            return Depend{Depend::Sink{makeList(sink.v, convertVec)}};
+          },
+          [&](const wrapped::InOut &inout) {
+            const auto &type = std::get<parser::OmpDependenceType>(inout.t);
+            const auto &objects =
+                std::get<std::list<parser::Designator>>(inout.t);
+            return Depend{
+                Depend::InOut{{enum_cast<Depend::Type>(type.v),
+                               makeList(objects, convertDesignator)}}};
+          },
+      },
+      inp.v.u);
+}
+
+/// DEVICE clause: an optional device modifier and a device expression.
+struct Device {
+  ENUM_CLASS(DeviceModifier, Ancestor, Device_Num)
+  using TupleTrait = std::true_type;
+  std::tuple<std::optional<DeviceModifier>, SomeExpr> t;
+};
+
+Device make(const parser::OmpClause::Device &inp,
+            semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpDeviceClause
+  using wrapped = parser::OmpDeviceClause;
+
+  const auto &modifier =
+      std::get<std::optional<wrapped::DeviceModifier>>(inp.v.t);
+  const auto &device = std::get<parser::ScalarIntExpr>(inp.v.t);
+  auto toModifier = [](auto &&m) -> Device::DeviceModifier {
+    return enum_cast<Device::DeviceModifier>(m);
+  };
+  return Device{{maybeApply(toModifier, modifier), makeExpr(device, semaCtx)}};
+}
+
+/// DEVICE_TYPE clause: the device kind.
+struct DeviceType {
+  ENUM_CLASS(Type, Any, Host, Nohost)
+  using WrapperTrait = std::true_type;
+  Type v;
+};
+
+DeviceType make(const parser::OmpClause::DeviceType &inp,
+                semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpDeviceTypeClause
+  const auto kind = enum_cast<DeviceType::Type>(inp.v.v);
+  return DeviceType{kind};
+}
+
+/// DIST_SCHEDULE clause: an optional expression.
+struct DistSchedule {
+  using WrapperTrait = std::true_type;
+  MaybeExpr v;
+};
+
+DistSchedule make(const parser::OmpClause::DistSchedule &inp,
+                  semantics::SemanticsContext &semaCtx) {
+  // inp.v -> std::optional<parser::ScalarIntExpr>
+  MaybeExpr chunk = maybeApply(makeExpr(semaCtx), inp.v);
+  return DistSchedule{std::move(chunk)};
+}
+
+/// ENTER clause: a list of objects.
+struct Enter {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+Enter make(const parser::OmpClause::Enter &inp,
+           semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return Enter{std::move(objects)};
+}
+
+/// FILTER clause: a single integer expression.
+struct Filter {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Filter make(const parser::OmpClause::Filter &inp,
+            semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntExpr
+  SomeExpr threadNum = makeExpr(inp.v, semaCtx);
+  return Filter{std::move(threadNum)};
+}
+
+/// FINAL clause: a single logical expression.
+struct Final {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Final make(const parser::OmpClause::Final &inp,
+           semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarLogicalExpr
+  SomeExpr condition = makeExpr(inp.v, semaCtx);
+  return Final{std::move(condition)};
+}
+
+/// FIRSTPRIVATE clause: a list of objects.
+struct Firstprivate {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+Firstprivate make(const parser::OmpClause::Firstprivate &inp,
+                  semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return Firstprivate{std::move(objects)};
+}
+
+/// FROM clause: a list of objects.
+struct From {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+From make(const parser::OmpClause::From &inp,
+          semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return From{std::move(objects)};
+}
+
+/// GRAINSIZE clause: a single integer expression.
+struct Grainsize {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Grainsize make(const parser::OmpClause::Grainsize &inp,
+               semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntExpr
+  SomeExpr size = makeExpr(inp.v, semaCtx);
+  return Grainsize{std::move(size)};
+}
+
+/// HAS_DEVICE_ADDR clause: a list of objects.
+struct HasDeviceAddr {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+HasDeviceAddr make(const parser::OmpClause::HasDeviceAddr &inp,
+                   semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return HasDeviceAddr{std::move(objects)};
+}
+
+/// HINT clause: a single constant expression.
+struct Hint {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Hint make(const parser::OmpClause::Hint &inp,
+          semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ConstantExpr
+  SomeExpr hint = makeExpr(inp.v, semaCtx);
+  return Hint{std::move(hint)};
+}
+
+/// IF clause: an optional directive-name modifier and a logical condition.
+struct If {
+  ENUM_CLASS(DirectiveNameModifier, Parallel, Simd, Target, TargetData,
+             TargetEnterData, TargetExitData, TargetUpdate, Task, Taskloop,
+             Teams)
+  using TupleTrait = std::true_type;
+  std::tuple<std::optional<DirectiveNameModifier>, SomeExpr> t;
+};
+
+If make(const parser::OmpClause::If &inp,
+        semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpIfClause
+  using wrapped = parser::OmpIfClause;
+
+  const auto &nameModifier =
+      std::get<std::optional<wrapped::DirectiveNameModifier>>(inp.v.t);
+  const auto &condition = std::get<parser::ScalarLogicalExpr>(inp.v.t);
+  auto toModifier = [](auto &&m) -> If::DirectiveNameModifier {
+    return enum_cast<If::DirectiveNameModifier>(m);
+  };
+  return If{{maybeApply(toModifier, nameModifier),
+             makeExpr(condition, semaCtx)}};
+}
+
+/// IN_REDUCTION clause: a reduction operator and a list of objects.
+struct InReduction {
+  using TupleTrait = std::true_type;
+  std::tuple<ReductionOperator, ObjectList> t;
+};
+
+InReduction make(const parser::OmpClause::InReduction &inp,
+                 semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpInReductionClause
+  const auto &op = std::get<parser::OmpReductionOperator>(inp.v.t);
+  const auto &objects = std::get<parser::OmpObjectList>(inp.v.t);
+
+  ReductionOperator redOp = makeRedOp(op, semaCtx);
+  ObjectList objectList = makeList(objects, semaCtx);
+  return InReduction{{std::move(redOp), std::move(objectList)}};
+}
+
+/// IS_DEVICE_PTR clause: a list of objects.
+struct IsDevicePtr {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+IsDevicePtr make(const parser::OmpClause::IsDevicePtr &inp,
+                 semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return IsDevicePtr{std::move(objects)};
+}
+
+/// LASTPRIVATE clause: a list of objects.
+struct Lastprivate {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+Lastprivate make(const parser::OmpClause::Lastprivate &inp,
+                 semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return Lastprivate{std::move(objects)};
+}
+
+/// LINEAR clause: an optional modifier, a list of objects, and an optional
+/// step expression.
+struct Linear {
+  struct Modifier {
+    ENUM_CLASS(Type, Ref, Val, Uval)
+    using WrapperTrait = std::true_type;
+    Type v;
+  };
+  using TupleTrait = std::true_type;
+  std::tuple<std::optional<Modifier>, ObjectList, MaybeExpr> t;
+};
+
+Linear make(const parser::OmpClause::Linear &inp,
+            semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpLinearClause
+  using wrapped = parser::OmpLinearClause;
+
+  // Both parser alternatives carry `names` and `step`.
+  auto objectsOf = [&](const auto &s) {
+    return makeList(s.names, makeObject(semaCtx));
+  };
+  auto stepOf = [&](const auto &s) {
+    return maybeApply(makeExpr(semaCtx), s.step);
+  };
+
+  return std::visit(
+      common::visitors{
+          [&](const wrapped::WithModifier &s) {
+            auto kind = enum_cast<Linear::Modifier::Type>(s.modifier.v);
+            return Linear{{Linear::Modifier{kind}, objectsOf(s), stepOf(s)}};
+          },
+          [&](const wrapped::WithoutModifier &s) {
+            return Linear{{std::nullopt, objectsOf(s), stepOf(s)}};
+          },
+      },
+      inp.v.u);
+}
+
+/// LINK clause: a list of objects.
+struct Link {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+Link make(const parser::OmpClause::Link &inp,
+          semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return Link{std::move(objects)};
+}
+
+/// MAP clause: an optional map type (with optional ALWAYS) and a list of
+/// objects.
+struct Map {
+  struct MapType {
+    struct Always {
+      using EmptyTrait = std::true_type;
+    };
+    ENUM_CLASS(Type, To, From, Tofrom, Alloc, Release, Delete)
+    using TupleTrait = std::true_type;
+    std::tuple<std::optional<Always>, Type> t;
+  };
+  using TupleTrait = std::true_type;
+  std::tuple<std::optional<MapType>, ObjectList> t;
+};
+
+Map make(const parser::OmpClause::Map &inp,
+         semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpMapClause
+  const auto &mapType = std::get<std::optional<parser::OmpMapType>>(inp.v.t);
+  const auto &objects = std::get<parser::OmpObjectList>(inp.v.t);
+
+  auto toAlways = [](parser::OmpMapType::Always) {
+    return Map::MapType::Always{};
+  };
+  auto toMapType = [&](const parser::OmpMapType &s) {
+    const auto &always =
+        std::get<std::optional<parser::OmpMapType::Always>>(s.t);
+    const auto &kind = std::get<parser::OmpMapType::Type>(s.t);
+    return Map::MapType{
+        {maybeApply(toAlways, always), enum_cast<Map::MapType::Type>(kind)}};
+  };
+  return Map{{maybeApply(toMapType, mapType), makeList(objects, semaCtx)}};
+}
+
+/// NOCONTEXT clause: a single logical expression.
+struct Nocontext {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Nocontext make(const parser::OmpClause::Nocontext &inp,
+               semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarLogicalExpr
+  SomeExpr condition = makeExpr(inp.v, semaCtx);
+  return Nocontext{std::move(condition)};
+}
+
+/// NONTEMPORAL clause: a list of objects given by name.
+struct Nontemporal {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+Nontemporal make(const parser::OmpClause::Nontemporal &inp,
+                 semantics::SemanticsContext &semaCtx) {
+  // inp.v -> std::list<parser::Name>
+  ObjectList objects = makeList(inp.v, makeObject(semaCtx));
+  return Nontemporal{std::move(objects)};
+}
+
+/// NOVARIANTS clause: a single logical expression.
+struct Novariants {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Novariants make(const parser::OmpClause::Novariants &inp,
+                semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarLogicalExpr
+  SomeExpr condition = makeExpr(inp.v, semaCtx);
+  return Novariants{std::move(condition)};
+}
+
+/// NUM_TASKS clause: a single integer expression.
+struct NumTasks {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+NumTasks make(const parser::OmpClause::NumTasks &inp,
+              semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntExpr
+  SomeExpr count = makeExpr(inp.v, semaCtx);
+  return NumTasks{std::move(count)};
+}
+
+/// NUM_TEAMS clause: a single integer expression.
+struct NumTeams {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+NumTeams make(const parser::OmpClause::NumTeams &inp,
+              semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntExpr
+  SomeExpr count = makeExpr(inp.v, semaCtx);
+  return NumTeams{std::move(count)};
+}
+
+/// NUM_THREADS clause: a single integer expression.
+struct NumThreads {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+NumThreads make(const parser::OmpClause::NumThreads &inp,
+                semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntExpr
+  SomeExpr count = makeExpr(inp.v, semaCtx);
+  return NumThreads{std::move(count)};
+}
+
+/// OMPX_DYN_CGROUP_MEM clause: a single integer expression.
+struct OmpxDynCgroupMem {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+OmpxDynCgroupMem make(const parser::OmpClause::OmpxDynCgroupMem &inp,
+                      semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntExpr
+  SomeExpr amount = makeExpr(inp.v, semaCtx);
+  return OmpxDynCgroupMem{std::move(amount)};
+}
+
+/// ORDERED clause: an optional constant expression.
+struct Ordered {
+  using WrapperTrait = std::true_type;
+  MaybeExpr v;
+};
+
+Ordered make(const parser::OmpClause::Ordered &inp,
+             semantics::SemanticsContext &semaCtx) {
+  // inp.v -> std::optional<parser::ScalarIntConstantExpr>
+  MaybeExpr count = maybeApply(makeExpr(semaCtx), inp.v);
+  return Ordered{std::move(count)};
+}
+
+/// ORDER clause: an optional modifier kind and the order type.
+struct Order {
+  ENUM_CLASS(Kind, Reproducible, Unconstrained)
+  ENUM_CLASS(Type, Concurrent)
+  using TupleTrait = std::true_type;
+  std::tuple<std::optional<Kind>, Type> t;
+};
+
+Order make(const parser::OmpClause::Order &inp,
+           semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpOrderClause
+  using wrapped = parser::OmpOrderClause;
+
+  const auto &modifier =
+      std::get<std::optional<parser::OmpOrderModifier>>(inp.v.t);
+  const auto &type = std::get<wrapped::Type>(inp.v.t);
+  auto toKind = [](const parser::OmpOrderModifier &m) -> Order::Kind {
+    const auto &kind = std::get<parser::OmpOrderModifier::Kind>(m.u);
+    return enum_cast<Order::Kind>(kind);
+  };
+  return Order{{maybeApply(toKind, modifier), enum_cast<Order::Type>(type)}};
+}
+
+/// PARTIAL clause: an optional constant expression.
+struct Partial {
+  using WrapperTrait = std::true_type;
+  MaybeExpr v;
+};
+
+Partial make(const parser::OmpClause::Partial &inp,
+             semantics::SemanticsContext &semaCtx) {
+  // inp.v -> std::optional<parser::ScalarIntConstantExpr>
+  MaybeExpr factor = maybeApply(makeExpr(semaCtx), inp.v);
+  return Partial{std::move(factor)};
+}
+
+/// PRIORITY clause: a single integer expression.
+struct Priority {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Priority make(const parser::OmpClause::Priority &inp,
+              semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntExpr
+  SomeExpr priority = makeExpr(inp.v, semaCtx);
+  return Priority{std::move(priority)};
+}
+
+/// PRIVATE clause: a list of objects.
+struct Private {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+Private make(const parser::OmpClause::Private &inp,
+             semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return Private{std::move(objects)};
+}
+
+/// PROC_BIND clause: the binding kind.
+struct ProcBind {
+  ENUM_CLASS(Type, Close, Master, Spread, Primary)
+  using WrapperTrait = std::true_type;
+  Type v;
+};
+
+ProcBind make(const parser::OmpClause::ProcBind &inp,
+              semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpProcBindClause
+  const auto kind = enum_cast<ProcBind::Type>(inp.v.v);
+  return ProcBind{kind};
+}
+
+/// REDUCTION clause: a reduction operator and a list of objects.
+struct Reduction {
+  using TupleTrait = std::true_type;
+  std::tuple<ReductionOperator, ObjectList> t;
+};
+
+Reduction make(const parser::OmpClause::Reduction &inp,
+               semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpReductionClause
+  const auto &op = std::get<parser::OmpReductionOperator>(inp.v.t);
+  const auto &objects = std::get<parser::OmpObjectList>(inp.v.t);
+
+  ReductionOperator redOp = makeRedOp(op, semaCtx);
+  ObjectList objectList = makeList(objects, semaCtx);
+  return Reduction{{std::move(redOp), std::move(objectList)}};
+}
+
+/// SAFELEN clause: a single constant expression.
+struct Safelen {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Safelen make(const parser::OmpClause::Safelen &inp,
+             semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntConstantExpr
+  SomeExpr length = makeExpr(inp.v, semaCtx);
+  return Safelen{std::move(length)};
+}
+
+/// SCHEDULE clause: optional modifiers, the schedule kind, and an optional
+/// expression.
+struct Schedule {
+  ENUM_CLASS(ModType, Monotonic, Nonmonotonic, Simd)
+  struct ScheduleModifier {
+    using TupleTrait = std::true_type;
+    std::tuple<ModType, std::optional<ModType>> t;
+  };
+  ENUM_CLASS(ScheduleType, Static, Dynamic, Guided, Auto, Runtime)
+  using TupleTrait = std::true_type;
+  std::tuple<std::optional<ScheduleModifier>, ScheduleType, MaybeExpr> t;
+};
+
+Schedule make(const parser::OmpClause::Schedule &inp,
+              semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpScheduleClause
+  using wrapped = parser::OmpScheduleClause;
+  using ParserModifier = parser::OmpScheduleModifier;
+
+  const auto &modifier = std::get<std::optional<ParserModifier>>(inp.v.t);
+  const auto &kind = std::get<wrapped::ScheduleType>(inp.v.t);
+  const auto &chunk = std::get<std::optional<parser::ScalarIntExpr>>(inp.v.t);
+
+  // Modifier1 / Modifier2: the enum value is reached through two `.v` levels.
+  auto toModType = [](auto &&m) {
+    return enum_cast<Schedule::ModType>(m.v.v);
+  };
+  auto toScheduleModifier = [&](auto &&s) -> Schedule::ScheduleModifier {
+    const auto &first = std::get<ParserModifier::Modifier1>(s.t);
+    const auto &second =
+        std::get<std::optional<ParserModifier::Modifier2>>(s.t);
+    return Schedule::ScheduleModifier{
+        {toModType(first), maybeApply(toModType, second)}};
+  };
+
+  return Schedule{{maybeApply(toScheduleModifier, modifier),
+                   enum_cast<Schedule::ScheduleType>(kind),
+                   maybeApply(makeExpr(semaCtx), chunk)}};
+}
+
+/// SHARED clause: a list of objects.
+struct Shared {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+Shared make(const parser::OmpClause::Shared &inp,
+            semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return Shared{std::move(objects)};
+}
+
+/// SIMDLEN clause: a single constant expression.
+struct Simdlen {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+Simdlen make(const parser::OmpClause::Simdlen &inp,
+             semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntConstantExpr
+  SomeExpr length = makeExpr(inp.v, semaCtx);
+  return Simdlen{std::move(length)};
+}
+
+/// SIZES clause: a list of integer expressions.
+struct Sizes {
+  using WrapperTrait = std::true_type;
+  List<SomeExpr> v;
+};
+
+Sizes make(const parser::OmpClause::Sizes &inp,
+           semantics::SemanticsContext &semaCtx) {
+  // inp.v -> std::list<parser::ScalarIntExpr>
+  List<SomeExpr> sizes = makeList(inp.v, makeExpr(semaCtx));
+  return Sizes{std::move(sizes)};
+}
+
+/// TASK_REDUCTION clause: a reduction operator and a list of objects.
+struct TaskReduction {
+  using TupleTrait = std::true_type;
+  std::tuple<ReductionOperator, ObjectList> t;
+};
+
+TaskReduction make(const parser::OmpClause::TaskReduction &inp,
+                   semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpReductionClause
+  const auto &op = std::get<parser::OmpReductionOperator>(inp.v.t);
+  const auto &objects = std::get<parser::OmpObjectList>(inp.v.t);
+
+  ReductionOperator redOp = makeRedOp(op, semaCtx);
+  ObjectList objectList = makeList(objects, semaCtx);
+  return TaskReduction{{std::move(redOp), std::move(objectList)}};
+}
+
+/// THREAD_LIMIT clause: a single integer expression.
+struct ThreadLimit {
+  using WrapperTrait = std::true_type;
+  SomeExpr v;
+};
+
+ThreadLimit make(const parser::OmpClause::ThreadLimit &inp,
+                 semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::ScalarIntExpr
+  SomeExpr limit = makeExpr(inp.v, semaCtx);
+  return ThreadLimit{std::move(limit)};
+}
+
+/// TO clause: a list of objects.
+struct To {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+To make(const parser::OmpClause::To &inp,
+        semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return To{std::move(objects)};
+}
+
+/// UNIFORM clause: a list of objects given by name.
+struct Uniform {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+Uniform make(const parser::OmpClause::Uniform &inp,
+             semantics::SemanticsContext &semaCtx) {
+  // inp.v -> std::list<parser::Name>
+  ObjectList objects = makeList(inp.v, makeObject(semaCtx));
+  return Uniform{std::move(objects)};
+}
+
+/// USE_DEVICE_ADDR clause: a list of objects.
+struct UseDeviceAddr {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+UseDeviceAddr make(const parser::OmpClause::UseDeviceAddr &inp,
+                   semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return UseDeviceAddr{std::move(objects)};
+}
+
+/// USE_DEVICE_PTR clause: a list of objects.
+struct UseDevicePtr {
+  using WrapperTrait = std::true_type;
+  ObjectList v;
+};
+
+UseDevicePtr make(const parser::OmpClause::UseDevicePtr &inp,
+                  semantics::SemanticsContext &semaCtx) {
+  // inp.v -> parser::OmpObjectList
+  ObjectList objects = makeList(inp.v, semaCtx);
+  return UseDevicePtr{std::move(objects)};
+}
+
+// Every clause class, in a fixed order (alternative indices follow this
+// order, so do not reorder casually).
+using UnionOfAllClauses = std::variant<
+    AcqRel, Acquire, AdjustArgs, Affinity, Align, Aligned, Allocate, Allocator,
+    AppendArgs, At, AtomicDefaultMemOrder,
+    Bind,
+    CancellationConstructType, Capture, Collapse, Compare, Copyprivate, Copyin,
+    Default, Defaultmap, Depend, Depobj, Destroy, Detach, Device, DeviceType,
+    DistSchedule, Doacross, DynamicAllocators,
+    Enter, Exclusive,
+    Fail, Filter, Final, Firstprivate, Flush, From, Full,
+    Grainsize,
+    HasDeviceAddr, Hint,
+    If, InReduction, Inbranch, Inclusive, Indirect, Init, IsDevicePtr,
+    Lastprivate, Linear, Link,
+    Map, Match, MemoryOrder, Mergeable, Message,
+    Nogroup, Nowait, Nocontext, Nontemporal, Notinbranch, Novariants, NumTasks,
+    NumTeams, NumThreads,
+    OmpxAttribute, OmpxDynCgroupMem, OmpxBare, Order, Ordered,
+    Partial, Priority, Private, ProcBind,
+    Read, Reduction, Relaxed, Release, ReverseOffload,
+    Safelen, Schedule, SeqCst, Severity, Shared, Simd, Simdlen, Sizes,
+    TaskReduction, ThreadLimit, Threadprivate, Threads, To,
+    UnifiedAddress, UnifiedSharedMemory, Uniform, Unknown, Untied, Update, Use,
+    UseDeviceAddr, UseDevicePtr, UsesAllocators,
+    Weak, When, Write>;
+
+} // namespace clause
+
+/// A clause in the new representation: its source text, its numeric clause
+/// id, and the clause-specific contents.
+struct Clause {
+  using UnionTrait = std::true_type;
+
+  parser::CharBlock source;    // source range of the clause
+  llvm::omp::Clause id;        // numeric id of the clause
+  clause::UnionOfAllClauses u; // clause-specific contents
+};
+
+/// Convert a parser::OmpClause into a Clause, dispatching on the held
+/// alternative to the matching clause::make overload.
+Clause makeClause(const Fortran::parser::OmpClause &cls,
+                  semantics::SemanticsContext &semaCtx) {
+  const llvm::omp::Clause id = getClauseId(cls);
+  auto build = [&](auto &&alt) {
+    return Clause{cls.source, id, clause::make(alt, semaCtx)};
+  };
+  return std::visit(build, cls.u);
+}
+
+/// Convert every clause of a parser clause list, preserving order.
+List<Clause> makeList(const parser::OmpClauseList &clauses,
+                      semantics::SemanticsContext &semaCtx) {
+  auto convert = [&](const parser::OmpClause &c) {
+    return makeClause(c, semaCtx);
+  };
+  return makeList(clauses.v, convert);
+}
+} // namespace omp
+
//===----------------------------------------------------------------------===//
// DataSharingProcessor
//===----------------------------------------------------------------------===//
>From e6ed53d812a752d33428d9f6d03b3be15b575db9 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 6 Feb 2024 17:06:29 -0600
Subject: [PATCH 2/6] [flang][OpenMP] Convert unique clauses in ClauseProcessor
Temporarily rename old clause list to `clauses2`, old clause iterator
to `ClauseIterator2`.
Change `findUniqueClause` to iterate over `omp::Clause` objects,
modify all handlers to operate on `omp::clause::xyz` equivalents.
---
flang/lib/Lower/OpenMP.cpp | 242 +++++++++++++++++--------------------
1 file changed, 114 insertions(+), 128 deletions(-)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 5ddc1eaa27e003..09797ca53c911e 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1670,7 +1670,8 @@ class ClauseProcessor {
ClauseProcessor(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
const Fortran::parser::OmpClauseList &clauses)
- : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
+ : converter(converter), semaCtx(semaCtx), clauses2(clauses),
+ clauses(omp::makeList(clauses, semaCtx)) {}
// 'Unique' clauses: They can appear at most once in the clause list.
bool
@@ -1770,7 +1771,8 @@ class ClauseProcessor {
llvm::omp::Directive directive) const;
private:
- using ClauseIterator = std::list<ClauseTy>::const_iterator;
+ using ClauseIterator = omp::List<omp::Clause>::const_iterator;
+ using ClauseIterator2 = std::list<ClauseTy>::const_iterator;
/// Utility to find a clause within a range in the clause list.
template <typename T>
@@ -1783,14 +1785,26 @@ class ClauseProcessor {
return end;
}
+  /// Return an iterator to the first clause in [begin, end) whose variant
+  /// holds the alternative `T`, or `end` when there is no such clause.
+  /// Operates on the parser-level (`ClauseTy`) clause list.
+  template <typename T>
+  static ClauseIterator2 findClause2(ClauseIterator2 begin,
+                                     ClauseIterator2 end) {
+    ClauseIterator2 cur = begin;
+    while (cur != end && !std::get_if<T>(&cur->u))
+      ++cur;
+    return cur;
+  }
+
/// Return the first instance of the given clause found in the clause list or
/// `nullptr` if not present. If more than one instance is expected, use
/// `findRepeatableClause` instead.
template <typename T>
const T *
findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const {
- ClauseIterator it = findClause<T>(clauses.v.begin(), clauses.v.end());
- if (it != clauses.v.end()) {
+ ClauseIterator it = findClause<T>(clauses.begin(), clauses.end());
+ if (it != clauses.end()) {
if (source)
*source = &it->source;
return &std::get<T>(it->u);
@@ -1805,9 +1819,9 @@ class ClauseProcessor {
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
- ClauseIterator nextIt, endIt = clauses.v.end();
- for (ClauseIterator it = clauses.v.begin(); it != endIt; it = nextIt) {
- nextIt = findClause<T>(it, endIt);
+ ClauseIterator2 nextIt, endIt = clauses2.v.end();
+ for (ClauseIterator2 it = clauses2.v.begin(); it != endIt; it = nextIt) {
+ nextIt = findClause2<T>(it, endIt);
if (nextIt != endIt) {
callbackFn(&std::get<T>(nextIt->u), nextIt->source);
@@ -1830,7 +1844,8 @@ class ClauseProcessor {
Fortran::lower::AbstractConverter &converter;
Fortran::semantics::SemanticsContext &semaCtx;
- const Fortran::parser::OmpClauseList &clauses;
+ const Fortran::parser::OmpClauseList &clauses2;
+ omp::List<omp::Clause> clauses;
};
//===----------------------------------------------------------------------===//
@@ -2295,64 +2310,55 @@ class ReductionProcessor {
};
static mlir::omp::ScheduleModifier
-translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) {
- switch (m.v) {
- case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic:
+translateScheduleModifier(const omp::clause::Schedule::ModType &m) {
+ switch (m) {
+ case omp::clause::Schedule::ModType::Monotonic:
return mlir::omp::ScheduleModifier::monotonic;
- case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic:
+ case omp::clause::Schedule::ModType::Nonmonotonic:
return mlir::omp::ScheduleModifier::nonmonotonic;
- case Fortran::parser::OmpScheduleModifierType::ModType::Simd:
+ case omp::clause::Schedule::ModType::Simd:
return mlir::omp::ScheduleModifier::simd;
}
return mlir::omp::ScheduleModifier::none;
}
static mlir::omp::ScheduleModifier
-getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) {
- const auto &modifier =
- std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
+getScheduleModifier(const omp::clause::Schedule &clause) {
+ using ScheduleModifier = omp::clause::Schedule::ScheduleModifier;
+ const auto &modifier = std::get<std::optional<ScheduleModifier>>(clause.t);
// The input may have the modifier any order, so we look for one that isn't
// SIMD. If modifier is not set at all, fall down to the bottom and return
// "none".
if (modifier) {
- const auto &modType1 =
- std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
- if (modType1.v.v ==
- Fortran::parser::OmpScheduleModifierType::ModType::Simd) {
- const auto &modType2 = std::get<
- std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
- modifier->t);
- if (modType2 &&
- modType2->v.v !=
- Fortran::parser::OmpScheduleModifierType::ModType::Simd)
- return translateScheduleModifier(modType2->v);
-
+ using ModType = omp::clause::Schedule::ModType;
+ const auto &modType1 = std::get<ModType>(modifier->t);
+ if (modType1 == ModType::Simd) {
+ const auto &modType2 = std::get<std::optional<ModType>>(modifier->t);
+ if (modType2 && *modType2 != ModType::Simd)
+ return translateScheduleModifier(*modType2);
return mlir::omp::ScheduleModifier::none;
}
- return translateScheduleModifier(modType1.v);
+ return translateScheduleModifier(modType1);
}
return mlir::omp::ScheduleModifier::none;
}
static mlir::omp::ScheduleModifier
-getSimdModifier(const Fortran::parser::OmpScheduleClause &x) {
- const auto &modifier =
- std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
+getSimdModifier(const omp::clause::Schedule &clause) {
+ using ScheduleModifier = omp::clause::Schedule::ScheduleModifier;
+ const auto &modifier = std::get<std::optional<ScheduleModifier>>(clause.t);
// Either of the two possible modifiers in the input can be the SIMD modifier,
// so look in either one, and return simd if we find one. Not found = return
// "none".
if (modifier) {
- const auto &modType1 =
- std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
- if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd)
+ using ModType = omp::clause::Schedule::ModType;
+ const auto &modType1 = std::get<ModType>(modifier->t);
+ if (modType1 == ModType::Simd)
return mlir::omp::ScheduleModifier::simd;
- const auto &modType2 = std::get<
- std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
- modifier->t);
- if (modType2 && modType2->v.v ==
- Fortran::parser::OmpScheduleModifierType::ModType::Simd)
+ const auto &modType2 = std::get<std::optional<ModType>>(modifier->t);
+ if (modType2 && *modType2 == ModType::Simd)
return mlir::omp::ScheduleModifier::simd;
}
return mlir::omp::ScheduleModifier::none;
@@ -2406,21 +2412,21 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
genObjectList(ompObjectList, converter, allocateOperands);
}
-static mlir::omp::ClauseProcBindKindAttr genProcBindKindAttr(
- fir::FirOpBuilder &firOpBuilder,
- const Fortran::parser::OmpClause::ProcBind *procBindClause) {
+static mlir::omp::ClauseProcBindKindAttr
+genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder,
+ const omp::clause::ProcBind &clause) {
mlir::omp::ClauseProcBindKind procBindKind;
- switch (procBindClause->v.v) {
- case Fortran::parser::OmpProcBindClause::Type::Master:
+ switch (clause.v) {
+ case omp::clause::ProcBind::Type::Master:
procBindKind = mlir::omp::ClauseProcBindKind::Master;
break;
- case Fortran::parser::OmpProcBindClause::Type::Close:
+ case omp::clause::ProcBind::Type::Close:
procBindKind = mlir::omp::ClauseProcBindKind::Close;
break;
- case Fortran::parser::OmpProcBindClause::Type::Spread:
+ case omp::clause::ProcBind::Type::Spread:
procBindKind = mlir::omp::ClauseProcBindKind::Spread;
break;
- case Fortran::parser::OmpProcBindClause::Type::Primary:
+ case omp::clause::ProcBind::Type::Primary:
procBindKind = mlir::omp::ClauseProcBindKind::Primary;
break;
}
@@ -2518,9 +2524,8 @@ bool ClauseProcessor::processCollapse(
}
std::int64_t collapseValue = 1l;
- if (auto *collapseClause = findUniqueClause<ClauseTy::Collapse>()) {
- const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
- collapseValue = Fortran::evaluate::ToInt64(*expr).value();
+ if (auto *clause = findUniqueClause<omp::clause::Collapse>()) {
+ collapseValue = Fortran::evaluate::ToInt64(clause->v).value();
found = true;
}
@@ -2559,19 +2564,19 @@ bool ClauseProcessor::processCollapse(
}
bool ClauseProcessor::processDefault() const {
- if (auto *defaultClause = findUniqueClause<ClauseTy::Default>()) {
+ if (auto *clause = findUniqueClause<omp::clause::Default>()) {
// Private, Firstprivate, Shared, None
- switch (defaultClause->v.v) {
- case Fortran::parser::OmpDefaultClause::Type::Shared:
- case Fortran::parser::OmpDefaultClause::Type::None:
+ switch (clause->v) {
+ case omp::clause::Default::Type::Shared:
+ case omp::clause::Default::Type::None:
// Default clause with shared or none do not require any handling since
// Shared is the default behavior in the IR and None is only required
// for semantic checks.
break;
- case Fortran::parser::OmpDefaultClause::Type::Private:
+ case omp::clause::Default::Type::Private:
// TODO Support default(private)
break;
- case Fortran::parser::OmpDefaultClause::Type::Firstprivate:
+ case omp::clause::Default::Type::Firstprivate:
// TODO Support default(firstprivate)
break;
}
@@ -2583,20 +2588,17 @@ bool ClauseProcessor::processDefault() const {
bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const {
const Fortran::parser::CharBlock *source = nullptr;
- if (auto *deviceClause = findUniqueClause<ClauseTy::Device>(&source)) {
+ if (auto *clause = findUniqueClause<omp::clause::Device>(&source)) {
mlir::Location clauseLocation = converter.genLocation(*source);
- if (auto deviceModifier = std::get<
- std::optional<Fortran::parser::OmpDeviceClause::DeviceModifier>>(
- deviceClause->v.t)) {
- if (deviceModifier ==
- Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) {
+ if (auto deviceModifier =
+ std::get<std::optional<omp::clause::Device::DeviceModifier>>(
+ clause->t)) {
+ if (deviceModifier == omp::clause::Device::DeviceModifier::Ancestor) {
TODO(clauseLocation, "OMPD_target Device Modifier Ancestor");
}
}
- if (const auto *deviceExpr = Fortran::semantics::GetExpr(
- std::get<Fortran::parser::ScalarIntExpr>(deviceClause->v.t))) {
- result = fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
- }
+ const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
+ result = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
return true;
}
return false;
@@ -2604,16 +2606,16 @@ bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
bool ClauseProcessor::processDeviceType(
mlir::omp::DeclareTargetDeviceType &result) const {
- if (auto *deviceTypeClause = findUniqueClause<ClauseTy::DeviceType>()) {
+ if (auto *clause = findUniqueClause<omp::clause::DeviceType>()) {
// Case: declare target ... device_type(any | host | nohost)
- switch (deviceTypeClause->v.v) {
- case Fortran::parser::OmpDeviceTypeClause::Type::Nohost:
+ switch (clause->v) {
+ case omp::clause::DeviceType::Type::Nohost:
result = mlir::omp::DeclareTargetDeviceType::nohost;
break;
- case Fortran::parser::OmpDeviceTypeClause::Type::Host:
+ case omp::clause::DeviceType::Type::Host:
result = mlir::omp::DeclareTargetDeviceType::host;
break;
- case Fortran::parser::OmpDeviceTypeClause::Type::Any:
+ case omp::clause::DeviceType::Type::Any:
result = mlir::omp::DeclareTargetDeviceType::any;
break;
}
@@ -2625,12 +2627,12 @@ bool ClauseProcessor::processDeviceType(
bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const {
const Fortran::parser::CharBlock *source = nullptr;
- if (auto *finalClause = findUniqueClause<ClauseTy::Final>(&source)) {
+ if (auto *clause = findUniqueClause<omp::clause::Final>(&source)) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location clauseLocation = converter.genLocation(*source);
- mlir::Value finalVal = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
+ mlir::Value finalVal =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
result = firOpBuilder.createConvert(clauseLocation,
firOpBuilder.getI1Type(), finalVal);
return true;
@@ -2639,10 +2641,9 @@ bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
}
bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const {
- if (auto *hintClause = findUniqueClause<ClauseTy::Hint>()) {
+ if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
- int64_t hintValue = *Fortran::evaluate::ToInt64(*expr);
+ int64_t hintValue = *Fortran::evaluate::ToInt64(clause->v);
result = firOpBuilder.getI64IntegerAttr(hintValue);
return true;
}
@@ -2650,20 +2651,19 @@ bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const {
}
bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const {
- return markClauseOccurrence<ClauseTy::Mergeable>(result);
+ return markClauseOccurrence<omp::clause::Mergeable>(result);
}
bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
- return markClauseOccurrence<ClauseTy::Nowait>(result);
+ return markClauseOccurrence<omp::clause::Nowait>(result);
}
bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const {
// TODO Get lower and upper bounds for num_teams when parser is updated to
// accept both.
- if (auto *numTeamsClause = findUniqueClause<ClauseTy::NumTeams>()) {
- result = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(numTeamsClause->v), stmtCtx));
+ if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
+ result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
@@ -2671,22 +2671,20 @@ bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
bool ClauseProcessor::processNumThreads(
Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
- if (auto *numThreadsClause = findUniqueClause<ClauseTy::NumThreads>()) {
+ if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
+ result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const {
- if (auto *orderedClause = findUniqueClause<ClauseTy::Ordered>()) {
+ if (auto *clause = findUniqueClause<omp::clause::Ordered>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t orderedClauseValue = 0l;
- if (orderedClause->v.has_value()) {
- const auto *expr = Fortran::semantics::GetExpr(orderedClause->v);
- orderedClauseValue = *Fortran::evaluate::ToInt64(*expr);
+ if (clause->v.has_value()) {
+ orderedClauseValue = *Fortran::evaluate::ToInt64(*clause->v);
}
result = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
return true;
@@ -2696,9 +2694,8 @@ bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const {
bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const {
- if (auto *priorityClause = findUniqueClause<ClauseTy::Priority>()) {
- result = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
+ if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
+ result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
@@ -2706,20 +2703,19 @@ bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx,
bool ClauseProcessor::processProcBind(
mlir::omp::ClauseProcBindKindAttr &result) const {
- if (auto *procBindClause = findUniqueClause<ClauseTy::ProcBind>()) {
+ if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- result = genProcBindKindAttr(firOpBuilder, procBindClause);
+ result = genProcBindKindAttr(firOpBuilder, *clause);
return true;
}
return false;
}
bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const {
- if (auto *safelenClause = findUniqueClause<ClauseTy::Safelen>()) {
+ if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- const auto *expr = Fortran::semantics::GetExpr(safelenClause->v);
const std::optional<std::int64_t> safelenVal =
- Fortran::evaluate::ToInt64(*expr);
+ Fortran::evaluate::ToInt64(clause->v);
result = firOpBuilder.getI64IntegerAttr(*safelenVal);
return true;
}
@@ -2730,41 +2726,38 @@ bool ClauseProcessor::processSchedule(
mlir::omp::ClauseScheduleKindAttr &valAttr,
mlir::omp::ScheduleModifierAttr &modifierAttr,
mlir::UnitAttr &simdModifierAttr) const {
- if (auto *scheduleClause = findUniqueClause<ClauseTy::Schedule>()) {
+ if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::MLIRContext *context = firOpBuilder.getContext();
- const Fortran::parser::OmpScheduleClause &scheduleType = scheduleClause->v;
- const auto &scheduleClauseKind =
- std::get<Fortran::parser::OmpScheduleClause::ScheduleType>(
- scheduleType.t);
+ const auto &scheduleType =
+ std::get<omp::clause::Schedule::ScheduleType>(clause->t);
mlir::omp::ClauseScheduleKind scheduleKind;
- switch (scheduleClauseKind) {
- case Fortran::parser::OmpScheduleClause::ScheduleType::Static:
+ switch (scheduleType) {
+ case omp::clause::Schedule::ScheduleType::Static:
scheduleKind = mlir::omp::ClauseScheduleKind::Static;
break;
- case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic:
+ case omp::clause::Schedule::ScheduleType::Dynamic:
scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic;
break;
- case Fortran::parser::OmpScheduleClause::ScheduleType::Guided:
+ case omp::clause::Schedule::ScheduleType::Guided:
scheduleKind = mlir::omp::ClauseScheduleKind::Guided;
break;
- case Fortran::parser::OmpScheduleClause::ScheduleType::Auto:
+ case omp::clause::Schedule::ScheduleType::Auto:
scheduleKind = mlir::omp::ClauseScheduleKind::Auto;
break;
- case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime:
+ case omp::clause::Schedule::ScheduleType::Runtime:
scheduleKind = mlir::omp::ClauseScheduleKind::Runtime;
break;
}
- mlir::omp::ScheduleModifier scheduleModifier =
- getScheduleModifier(scheduleClause->v);
+ mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
if (scheduleModifier != mlir::omp::ScheduleModifier::none)
modifierAttr =
mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
- if (getSimdModifier(scheduleClause->v) != mlir::omp::ScheduleModifier::none)
+ if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
simdModifierAttr = firOpBuilder.getUnitAttr();
valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
@@ -2775,25 +2768,19 @@ bool ClauseProcessor::processSchedule(
bool ClauseProcessor::processScheduleChunk(
Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
- if (auto *scheduleClause = findUniqueClause<ClauseTy::Schedule>()) {
- if (const auto &chunkExpr =
- std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
- scheduleClause->v.t)) {
- if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) {
- result = fir::getBase(converter.genExprValue(*expr, stmtCtx));
- }
- }
+ if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
+ if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
+ result = fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const {
- if (auto *simdlenClause = findUniqueClause<ClauseTy::Simdlen>()) {
+ if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- const auto *expr = Fortran::semantics::GetExpr(simdlenClause->v);
const std::optional<std::int64_t> simdlenVal =
- Fortran::evaluate::ToInt64(*expr);
+ Fortran::evaluate::ToInt64(clause->v);
result = firOpBuilder.getI64IntegerAttr(*simdlenVal);
return true;
}
@@ -2802,16 +2789,15 @@ bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const {
bool ClauseProcessor::processThreadLimit(
Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
- if (auto *threadLmtClause = findUniqueClause<ClauseTy::ThreadLimit>()) {
- result = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx));
+ if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
+ result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
- return markClauseOccurrence<ClauseTy::Untied>(result);
+ return markClauseOccurrence<omp::clause::Untied>(result);
}
//===----------------------------------------------------------------------===//
@@ -3205,7 +3191,7 @@ void ClauseProcessor::processTODO(mlir::Location currentLocation,
" construct");
};
- for (ClauseIterator it = clauses.v.begin(); it != clauses.v.end(); ++it)
+ for (ClauseIterator2 it = clauses2.v.begin(); it != clauses2.v.end(); ++it)
(checkUnhandledClause(std::get_if<Ts>(&it->u)), ...);
}
>From 3ba0e28a55cb54e4cfe983415f763583084d9a9b Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 6 Feb 2024 17:06:29 -0600
Subject: [PATCH 3/6] [flang][OpenMP] Convert repeatable clauses (except Map)
in ClauseProcessor
Rename `findRepeatableClause` to `findRepeatableClause2`, and make the
new `findRepeatableClause` operate on new `omp::Clause` objects.
Leave `Map` unchanged, because it will require more changes for it to
work.
---
flang/include/flang/Evaluate/tools.h | 23 +
flang/lib/Lower/OpenMP.cpp | 632 +++++++++++++--------------
2 files changed, 328 insertions(+), 327 deletions(-)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index d257da1a709642..e9999974944e88 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -430,6 +430,29 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
}
}
+struct ExtractSubstringHelper { // Overload set used by ExtractSubstring.
+  template <typename T> static std::optional<Substring> visit(T &&) {
+    return std::nullopt; // Any other node kind: not a substring.
+  }
+  // The target node: a substring designator.
+  static std::optional<Substring> visit(const Substring &e) { return e; }
+  // Designators and expressions: recurse into the active alternative.
+  template <typename T>
+  static std::optional<Substring> visit(const Designator<T> &e) {
+    return std::visit([](auto &&s) { return visit(s); }, e.u);
+  }
+  // Callers pass const lvalues, so these outrank the visit(T &&) catch-all.
+  template <typename T>
+  static std::optional<Substring> visit(const Expr<T> &e) {
+    return std::visit([](auto &&s) { return visit(s); }, e.u);
+  }
+};
+// If x designates a substring (via Expr/Designator layers), return it.
+template <typename A>
+std::optional<Substring> ExtractSubstring(const A &x) {
+  return ExtractSubstringHelper::visit(x);
+}
+
// If an expression is simply a whole symbol data designator,
// extract and return that symbol, else null.
template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 09797ca53c911e..921649b868040c 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -72,9 +72,9 @@ getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
return sym;
}
-static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- llvm::SmallVectorImpl<mlir::Value> &operands) {
+static void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
+ Fortran::lower::AbstractConverter &converter,
+ llvm::SmallVectorImpl<mlir::Value> &operands) {
auto addOperands = [&](Fortran::lower::SymbolRef sym) {
const mlir::Value variable = converter.getSymbolAddress(sym);
if (variable) {
@@ -93,27 +93,6 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
}
}
-static void gatherFuncAndVarSyms(
- const Fortran::parser::OmpObjectList &objList,
- mlir::omp::DeclareTargetCaptureClause clause,
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
- for (const Fortran::parser::OmpObject &ompObject : objList.v) {
- Fortran::common::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::Designator &designator) {
- if (const Fortran::parser::Name *name =
- Fortran::semantics::getDesignatorNameIfDataRef(
- designator)) {
- symbolAndClause.emplace_back(clause, *name->symbol);
- }
- },
- [&](const Fortran::parser::Name &name) {
- symbolAndClause.emplace_back(clause, *name.symbol);
- }},
- ompObject.u);
- }
-}
-
static Fortran::lower::pft::Evaluation *
getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) {
// Return the Evaluation of the innermost collapsed loop, or the current one
@@ -1258,6 +1237,32 @@ List<Clause> makeList(const parser::OmpClauseList &clauses,
}
} // namespace omp
+static void genObjectList(const omp::ObjectList &objects,
+                          Fortran::lower::AbstractConverter &converter,
+                          llvm::SmallVectorImpl<mlir::Value> &operands) {
+  for (const omp::Object &obj : objects) {
+    const Fortran::semantics::Symbol *sym = obj.sym;
+    assert(sym && "Expected Symbol");
+    if (mlir::Value addr = converter.getSymbolAddress(*sym)) {
+      operands.push_back(addr);
+    } else if (const auto *host =
+                   sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
+      // No address of its own: use the host symbol's address and copy the
+      // host's binding onto `sym`.
+      operands.push_back(converter.getSymbolAddress(host->symbol()));
+      converter.copySymbolBinding(host->symbol(), *sym);
+    }
+  }
+}
+
+static void gatherFuncAndVarSyms(
+    const omp::ObjectList &objects,
+    mlir::omp::DeclareTargetCaptureClause clause,
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
+  for (const omp::Object &entry : objects) // One (clause, symbol) per object.
+    symbolAndClause.emplace_back(clause, *entry.sym);
+}
+
//===----------------------------------------------------------------------===//
// DataSharingProcessor
//===----------------------------------------------------------------------===//
@@ -1719,9 +1724,8 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
- bool
- processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
- mlir::Value &result) const;
+ bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
+ mlir::Value &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
@@ -1816,6 +1820,26 @@ class ClauseProcessor {
/// if at least one instance was found.
template <typename T>
bool findRepeatableClause(
+ std::function<void(const T &, const Fortran::parser::CharBlock &source)>
+ callbackFn) const {
+ bool found = false;
+ ClauseIterator nextIt, endIt = clauses.end();
+ for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
+ nextIt = findClause<T>(it, endIt);
+
+ if (nextIt != endIt) {
+ callbackFn(std::get<T>(nextIt->u), nextIt->source);
+ found = true;
+ ++nextIt;
+ }
+ }
+ return found;
+ }
+
+ /// Call `callbackFn` for each occurrence of the given clause. Return `true`
+ /// if at least one instance was found.
+ template <typename T>
+ bool findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
@@ -1881,9 +1905,9 @@ class ReductionProcessor {
IEOR
};
static ReductionIdentifier
- getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
+ getReductionType(const omp::clause::ProcedureDesignator &pd) {
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
- getRealName(pd).ToString())
+ getRealName(pd.v.sym).ToString())
.Case("max", ReductionIdentifier::MAX)
.Case("min", ReductionIdentifier::MIN)
.Case("iand", ReductionIdentifier::IAND)
@@ -1895,35 +1919,33 @@ class ReductionProcessor {
}
static ReductionIdentifier getReductionType(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+ omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Add:
return ReductionIdentifier::ADD;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
return ReductionIdentifier::SUBTRACT;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
return ReductionIdentifier::MULTIPLY;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case omp::clause::DefinedOperator::IntrinsicOperator::AND:
return ReductionIdentifier::AND;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
return ReductionIdentifier::EQV;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case omp::clause::DefinedOperator::IntrinsicOperator::OR:
return ReductionIdentifier::OR;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
return ReductionIdentifier::NEQV;
default:
llvm_unreachable("unexpected intrinsic operator in reduction");
}
}
- static bool supportedIntrinsicProcReduction(
- const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- if (!name->symbol->GetUltimate().attrs().test(
- Fortran::semantics::Attr::INTRINSIC))
+ static bool
+ supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd) {
+ Fortran::semantics::Symbol *sym = pd.v.sym;
+ if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
return false;
- auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString())
+ auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
.Case("max", true)
.Case("min", true)
.Case("iand", true)
@@ -1934,15 +1956,13 @@ class ReductionProcessor {
}
static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::Name *name) {
- return name->symbol->GetUltimate().name();
+ getRealName(const Fortran::semantics::Symbol *symbol) {
+ return symbol->GetUltimate().name();
}
static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- return getRealName(name);
+ getRealName(const omp::clause::ProcedureDesignator &pd) {
+ return getRealName(pd.v.sym);
}
static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
@@ -1952,25 +1972,25 @@ class ReductionProcessor {
.str();
}
- static std::string getReductionName(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty) {
+ static std::string
+ getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Type ty) {
std::string reductionName;
switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Add:
reductionName = "add_reduction";
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
reductionName = "multiply_reduction";
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case omp::clause::DefinedOperator::IntrinsicOperator::AND:
return "and_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
return "eqv_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case omp::clause::DefinedOperator::IntrinsicOperator::OR:
return "or_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
return "neqv_reduction";
default:
reductionName = "other_reduction";
@@ -2214,7 +2234,7 @@ class ReductionProcessor {
static void
addReductionDecl(mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpReductionClause &reduction,
+ const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -2222,13 +2242,12 @@ class ReductionProcessor {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::omp::ReductionDeclareOp decl;
const auto &redOperator{
- std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
- const auto &objectList{
- std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+ std::get<omp::clause::ReductionOperator>(reduction.t)};
+ const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
if (const auto &redDefinedOp =
- std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+ std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
- std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+ std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
redDefinedOp->u)};
ReductionIdentifier redId = getReductionType(intrinsicOp);
switch (redId) {
@@ -2244,10 +2263,41 @@ class ReductionProcessor {
"Reduction of some intrinsic operators is not supported");
break;
}
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ for (const omp::Object &object : objectList) {
+ if (const Fortran::semantics::Symbol *symbol = object.sym) {
+ if (reductionSymbols)
+ reductionSymbols->push_back(symbol);
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+ symVal = declOp.getBase();
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ if (redType.isa<fir::LogicalType>())
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
+ redType, currentLocation);
+ else if (redType.isIntOrIndexOrFloat()) {
+ decl = createReductionDecl(firOpBuilder,
+ getReductionName(intrinsicOp, redType),
+ redId, redType, currentLocation);
+ } else {
+ TODO(currentLocation, "Reduction of some types is not supported");
+ }
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ } else if (const auto *reductionIntrinsic =
+ std::get_if<omp::clause::ProcedureDesignator>(
+ &redOperator.u)) {
+ if (ReductionProcessor::supportedIntrinsicProcReduction(
+ *reductionIntrinsic)) {
+ ReductionProcessor::ReductionIdentifier redId =
+ ReductionProcessor::getReductionType(*reductionIntrinsic);
+ for (const omp::Object &object : objectList) {
+ if (const Fortran::semantics::Symbol *symbol = object.sym) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
@@ -2256,55 +2306,18 @@ class ReductionProcessor {
mlir::Type redType =
symVal.getType().cast<fir::ReferenceType>().getEleTy();
reductionVars.push_back(symVal);
- if (redType.isa<fir::LogicalType>())
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
- redId, redType, currentLocation);
- else if (redType.isIntOrIndexOrFloat()) {
- decl = createReductionDecl(firOpBuilder,
- getReductionName(intrinsicOp, redType),
- redId, redType, currentLocation);
- } else {
- TODO(currentLocation, "Reduction of some types is not supported");
- }
+ assert(redType.isIntOrIndexOrFloat() &&
+ "Unsupported reduction type");
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(getRealName(*reductionIntrinsic).ToString(),
+ redType),
+ redId, redType, currentLocation);
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
firOpBuilder.getContext(), decl.getSymName()));
}
}
}
- } else if (const auto *reductionIntrinsic =
- std::get_if<Fortran::parser::ProcedureDesignator>(
- &redOperator.u)) {
- if (ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic)) {
- ReductionProcessor::ReductionIdentifier redId =
- ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- if (reductionSymbols)
- reductionSymbols->push_back(symbol);
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- assert(redType.isIntOrIndexOrFloat() &&
- "Unsupported reduction type");
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(getRealName(*reductionIntrinsic).ToString(),
- redType),
- redId, redType, currentLocation);
- reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- }
}
}
};
@@ -2366,7 +2379,7 @@ getSimdModifier(const omp::clause::Schedule &clause) {
static void
genAllocateClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpAllocateClause &ompAllocateClause,
+ const omp::clause::Allocate &clause,
llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -2374,21 +2387,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value allocatorOperand;
- const Fortran::parser::OmpObjectList &ompObjectList =
- std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
- const auto &allocateModifier = std::get<
- std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>(
- ompAllocateClause.t);
+ const omp::ObjectList &objectList = std::get<omp::ObjectList>(clause.t);
+ const auto &modifier =
+ std::get<std::optional<omp::clause::Allocate::Modifier>>(clause.t);
// If the allocate modifier is present, check if we only use the allocator
// submodifier. ALIGN in this context is unimplemented
const bool onlyAllocator =
- allocateModifier &&
- std::holds_alternative<
- Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
- allocateModifier->u);
+ modifier &&
+ std::holds_alternative<omp::clause::Allocate::Modifier::Allocator>(
+ modifier->u);
- if (allocateModifier && !onlyAllocator) {
+ if (modifier && !onlyAllocator) {
TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
}
@@ -2396,20 +2406,17 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
// to list of allocators, otherwise, add default allocator to
// list of allocators.
if (onlyAllocator) {
- const auto &allocatorValue = std::get<
- Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
- allocateModifier->u);
- allocatorOperand = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx));
- allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
- allocatorOperand);
+ const auto &value =
+ std::get<omp::clause::Allocate::Modifier::Allocator>(modifier->u);
+ mlir::Value operand =
+ fir::getBase(converter.genExprValue(value.v, stmtCtx));
+ allocatorOperands.append(objectList.size(), operand);
} else {
- allocatorOperand = firOpBuilder.createIntegerConstant(
+ mlir::Value operand = firOpBuilder.createIntegerConstant(
currentLocation, firOpBuilder.getI32Type(), 1);
- allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
- allocatorOperand);
+ allocatorOperands.append(objectList.size(), operand);
}
- genObjectList(ompObjectList, converter, allocateOperands);
+ genObjectList(objectList, converter, allocateOperands);
}
static mlir::omp::ClauseProcBindKindAttr
@@ -2436,20 +2443,17 @@ genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder,
static mlir::omp::ClauseTaskDependAttr
genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
- const Fortran::parser::OmpClause::Depend *dependClause) {
+ const omp::clause::Depend &clause) {
mlir::omp::ClauseTaskDepend pbKind;
- switch (
- std::get<Fortran::parser::OmpDependenceType>(
- std::get<Fortran::parser::OmpDependClause::InOut>(dependClause->v.u)
- .t)
- .v) {
- case Fortran::parser::OmpDependenceType::Type::In:
+ const auto &inOut = std::get<omp::clause::Depend::InOut>(clause.u);
+ switch (std::get<omp::clause::Depend::Type>(inOut.t)) {
+ case omp::clause::Depend::Type::In:
pbKind = mlir::omp::ClauseTaskDepend::taskdependin;
break;
- case Fortran::parser::OmpDependenceType::Type::Out:
+ case omp::clause::Depend::Type::Out:
pbKind = mlir::omp::ClauseTaskDepend::taskdependout;
break;
- case Fortran::parser::OmpDependenceType::Type::Inout:
+ case omp::clause::Depend::Type::Inout:
pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
break;
default:
@@ -2460,45 +2464,41 @@ genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
pbKind);
}
-static mlir::Value getIfClauseOperand(
- Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpClause::If *ifClause,
- Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
- mlir::Location clauseLocation) {
+static mlir::Value
+getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
+ const omp::clause::If &clause,
+ omp::clause::If::DirectiveNameModifier directiveName,
+ mlir::Location clauseLocation) {
// Only consider the clause if it's intended for the given directive.
- auto &directive = std::get<
- std::optional<Fortran::parser::OmpIfClause::DirectiveNameModifier>>(
- ifClause->v.t);
+ auto &directive =
+ std::get<std::optional<omp::clause::If::DirectiveNameModifier>>(clause.t);
if (directive && directive.value() != directiveName)
return nullptr;
Fortran::lower::StatementContext stmtCtx;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
mlir::Value ifVal = fir::getBase(
- converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
+ converter.genExprValue(std::get<omp::SomeExpr>(clause.t), stmtCtx));
return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
ifVal);
}
static void
addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpObjectList &useDeviceClause,
+ const omp::ObjectList &objects,
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&useDeviceSymbols) {
- genObjectList(useDeviceClause, converter, operands);
+ genObjectList(objects, converter, operands);
for (mlir::Value &operand : operands) {
checkMapType(operand.getLoc(), operand.getType());
useDeviceTypes.push_back(operand.getType());
useDeviceLocs.push_back(operand.getLoc());
}
- for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
- useDeviceSymbols.push_back(sym);
- }
+ for (const omp::Object &object : objects)
+ useDeviceSymbols.push_back(object.sym);
}
//===----------------------------------------------------------------------===//
@@ -2807,10 +2807,10 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
bool ClauseProcessor::processAllocate(
llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
- return findRepeatableClause<ClauseTy::Allocate>(
- [&](const ClauseTy::Allocate *allocateClause,
+ return findRepeatableClause<omp::clause::Allocate>(
+ [&](const omp::clause::Allocate &clause,
const Fortran::parser::CharBlock &) {
- genAllocateClause(converter, allocateClause->v, allocatorOperands,
+ genAllocateClause(converter, clause, allocatorOperands,
allocateOperands);
});
}
@@ -2827,12 +2827,12 @@ bool ClauseProcessor::processCopyin() const {
if (converter.isPresentShallowLookup(*sym))
converter.copyHostAssociateVar(*sym, copyAssignIP);
};
- bool hasCopyin = findRepeatableClause<ClauseTy::Copyin>(
- [&](const ClauseTy::Copyin *copyinClause,
+ bool hasCopyin = findRepeatableClause<omp::clause::Copyin>(
+ [&](const omp::clause::Copyin &clause,
const Fortran::parser::CharBlock &) {
- const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v;
- for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+ for (const omp::Object &object : clause.v) {
+ Fortran::semantics::Symbol *sym = object.sym;
+ assert(sym && "Expecting symbol");
if (const auto *commonDetails =
sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
for (const auto &mem : commonDetails->objects())
@@ -2865,38 +2865,30 @@ bool ClauseProcessor::processDepend(
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- return findRepeatableClause<ClauseTy::Depend>(
- [&](const ClauseTy::Depend *dependClause,
+ return findRepeatableClause<omp::clause::Depend>(
+ [&](const omp::clause::Depend &clause,
const Fortran::parser::CharBlock &) {
- const std::list<Fortran::parser::Designator> &depVal =
- std::get<std::list<Fortran::parser::Designator>>(
- std::get<Fortran::parser::OmpDependClause::InOut>(
- dependClause->v.u)
- .t);
+ assert(std::holds_alternative<omp::clause::Depend::InOut>(clause.u) &&
+ "Only InOut is handled at the moment");
+ const auto &inOut = std::get<omp::clause::Depend::InOut>(clause.u);
+ const auto &objects = std::get<omp::ObjectList>(inOut.t);
+
mlir::omp::ClauseTaskDependAttr dependTypeOperand =
- genDependKindAttr(firOpBuilder, dependClause);
- dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
- dependTypeOperand);
- for (const Fortran::parser::Designator &ompObject : depVal) {
- Fortran::semantics::Symbol *sym = nullptr;
- std::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::DataRef &designator) {
- if (const Fortran::parser::Name *name =
- std::get_if<Fortran::parser::Name>(&designator.u)) {
- sym = name->symbol;
- } else if (std::get_if<Fortran::common::Indirection<
- Fortran::parser::ArrayElement>>(
- &designator.u)) {
- TODO(converter.getCurrentLocation(),
- "array sections not supported for task depend");
- }
- },
- [&](const Fortran::parser::Substring &designator) {
- TODO(converter.getCurrentLocation(),
- "substring not supported for task depend");
- }},
- (ompObject).u);
+ genDependKindAttr(firOpBuilder, clause);
+ dependTypeOperands.append(objects.size(), dependTypeOperand);
+
+ for (const omp::Object &object : objects) {
+ assert(object.dsg && "Expecting designator");
+
+ if (Fortran::evaluate::ExtractSubstring(*object.dsg)) {
+ TODO(converter.getCurrentLocation(),
+ "substring not supported for task depend");
+ } else if (Fortran::evaluate::IsArrayElement(*object.dsg)) {
+ TODO(converter.getCurrentLocation(),
+ "array sections not supported for task depend");
+ }
+
+ Fortran::semantics::Symbol *sym = object.sym;
const mlir::Value variable = converter.getSymbolAddress(*sym);
dependOperands.push_back(variable);
}
@@ -2904,14 +2896,14 @@ bool ClauseProcessor::processDepend(
}
bool ClauseProcessor::processIf(
- Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
+ omp::clause::If::DirectiveNameModifier directiveName,
mlir::Value &result) const {
bool found = false;
- findRepeatableClause<ClauseTy::If>(
- [&](const ClauseTy::If *ifClause,
+ findRepeatableClause<omp::clause::If>(
+ [&](const omp::clause::If &clause,
const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
- mlir::Value operand = getIfClauseOperand(converter, ifClause,
+ mlir::Value operand = getIfClauseOperand(converter, clause,
directiveName, clauseLocation);
// Assume that, at most, a single 'if' clause will be applicable to the
// given directive.
@@ -2925,12 +2917,11 @@ bool ClauseProcessor::processIf(
bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
- return findRepeatableClause<ClauseTy::Link>(
- [&](const ClauseTy::Link *linkClause,
- const Fortran::parser::CharBlock &) {
+ return findRepeatableClause<omp::clause::Link>(
+ [&](const omp::clause::Link &clause, const Fortran::parser::CharBlock &) {
// Case: declare target link(var1, var2)...
gatherFuncAndVarSyms(
- linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result);
+ clause.v, mlir::omp::DeclareTargetCaptureClause::link, result);
});
}
@@ -2967,7 +2958,7 @@ bool ClauseProcessor::processMap(
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- return findRepeatableClause<ClauseTy::Map>(
+ return findRepeatableClause2<ClauseTy::Map>(
[&](const ClauseTy::Map *mapClause,
const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
@@ -3059,43 +3050,41 @@ bool ClauseProcessor::processReduction(
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
const {
- return findRepeatableClause<ClauseTy::Reduction>(
- [&](const ClauseTy::Reduction *reductionClause,
+ return findRepeatableClause<omp::clause::Reduction>(
+ [&](const omp::clause::Reduction &clause,
const Fortran::parser::CharBlock &) {
ReductionProcessor rp;
- rp.addReductionDecl(currentLocation, converter, reductionClause->v,
- reductionVars, reductionDeclSymbols,
- reductionSymbols);
+ rp.addReductionDecl(currentLocation, converter, clause, reductionVars,
+ reductionDeclSymbols, reductionSymbols);
});
}
bool ClauseProcessor::processSectionsReduction(
mlir::Location currentLocation) const {
- return findRepeatableClause<ClauseTy::Reduction>(
- [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) {
+ return findRepeatableClause<omp::clause::Reduction>(
+ [&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) {
TODO(currentLocation, "OMPC_Reduction");
});
}
bool ClauseProcessor::processTo(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
- return findRepeatableClause<ClauseTy::To>(
- [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) {
+ return findRepeatableClause<omp::clause::To>(
+ [&](const omp::clause::To &clause, const Fortran::parser::CharBlock &) {
// Case: declare target to(func, var1, var2)...
- gatherFuncAndVarSyms(toClause->v,
+ gatherFuncAndVarSyms(clause.v,
mlir::omp::DeclareTargetCaptureClause::to, result);
});
}
bool ClauseProcessor::processEnter(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
- return findRepeatableClause<ClauseTy::Enter>(
- [&](const ClauseTy::Enter *enterClause,
+ return findRepeatableClause<omp::clause::Enter>(
+ [&](const omp::clause::Enter &clause,
const Fortran::parser::CharBlock &) {
// Case: declare target enter(func, var1, var2)...
- gatherFuncAndVarSyms(enterClause->v,
- mlir::omp::DeclareTargetCaptureClause::enter,
- result);
+ gatherFuncAndVarSyms(
+ clause.v, mlir::omp::DeclareTargetCaptureClause::enter, result);
});
}
@@ -3105,11 +3094,11 @@ bool ClauseProcessor::processUseDeviceAddr(
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
const {
- return findRepeatableClause<ClauseTy::UseDeviceAddr>(
- [&](const ClauseTy::UseDeviceAddr *devAddrClause,
+ return findRepeatableClause<omp::clause::UseDeviceAddr>(
+ [&](const omp::clause::UseDeviceAddr &clause,
const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, devAddrClause->v, operands,
- useDeviceTypes, useDeviceLocs, useDeviceSymbols);
+ addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
+ useDeviceLocs, useDeviceSymbols);
});
}
@@ -3119,10 +3108,10 @@ bool ClauseProcessor::processUseDevicePtr(
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
const {
- return findRepeatableClause<ClauseTy::UseDevicePtr>(
- [&](const ClauseTy::UseDevicePtr *devPtrClause,
+ return findRepeatableClause<omp::clause::UseDevicePtr>(
+ [&](const omp::clause::UseDevicePtr &clause,
const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes,
+ addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
useDeviceLocs, useDeviceSymbols);
});
}
@@ -3131,7 +3120,7 @@ template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
- return findRepeatableClause<T>(
+ return findRepeatableClause2<T>(
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -3701,7 +3690,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
+ cp.processIf(omp::clause::If::DirectiveNameModifier::Parallel,
ifClauseOperand);
cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
cp.processProcBind(procBindKindAttr);
@@ -3800,8 +3789,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
dependOperands;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
- ifClauseOperand);
+ cp.processIf(omp::clause::If::DirectiveNameModifier::Task, ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processDefault();
cp.processFinal(stmtCtx, finalClauseOperand);
@@ -3862,7 +3850,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
+ cp.processIf(omp::clause::If::DirectiveNameModifier::TargetData,
ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
@@ -3893,19 +3881,16 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
mlir::UnitAttr nowaitAttr;
llvm::SmallVector<mlir::Value> mapOperands;
- Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
+ omp::clause::If::DirectiveNameModifier directiveName;
llvm::omp::Directive directive;
if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
- directiveName =
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
+ directiveName = omp::clause::If::DirectiveNameModifier::TargetEnterData;
directive = llvm::omp::Directive::OMPD_target_enter_data;
} else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
- directiveName =
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
+ directiveName = omp::clause::If::DirectiveNameModifier::TargetExitData;
directive = llvm::omp::Directive::OMPD_target_exit_data;
} else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
- directiveName =
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
+ directiveName = omp::clause::If::DirectiveNameModifier::TargetUpdate;
directive = llvm::omp::Directive::OMPD_target_update;
} else {
return nullptr;
@@ -4104,8 +4089,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
- ifClauseOperand);
+ cp.processIf(omp::clause::If::DirectiveNameModifier::Target, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processThreadLimit(stmtCtx, threadLimitOperand);
cp.processNowait(nowaitAttr);
@@ -4218,8 +4202,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
- ifClauseOperand);
+ cp.processIf(omp::clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processDefault();
cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
@@ -4258,8 +4241,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
+ omp::ObjectList objects{omp::makeList(*objectList, semaCtx)};
// Case: declare target(func, var1, var2)
- gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
+ gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
symbolAndClause);
} else if (const auto *clauseList{
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
@@ -4373,7 +4357,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
if (const auto &ompObjectList =
std::get<std::optional<Fortran::parser::OmpObjectList>>(
flushConstruct.t))
- genObjectList(*ompObjectList, converter, operandRange);
+ genObjectList2(*ompObjectList, converter, operandRange);
const auto &memOrderClause =
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
flushConstruct.t);
@@ -4535,8 +4519,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
loopVarTypeSize);
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
- cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
- ifClauseOperand);
+ cp.processIf(omp::clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
cp.processSimdlen(simdlenClauseOperand);
cp.processSafelen(safelenClauseOperand);
cp.processTODO<Fortran::parser::OmpClause::Aligned,
@@ -5338,106 +5321,101 @@ void Fortran::lower::genOpenMPReduction(
const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- for (const Fortran::parser::OmpClause &clause : clauseList.v) {
+ omp::List<omp::Clause> clauses{omp::makeList(clauseList, semaCtx)};
+
+ for (const omp::Clause &clause : clauses) {
if (const auto &reductionClause =
- std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
- const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
- reductionClause->v.t)};
- const auto &objectList{
- std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
+ std::get_if<omp::clause::Reduction>(&clause.u)) {
+ const auto &redOperator{
+ std::get<omp::clause::ReductionOperator>(reductionClause->t)};
+ const auto &objectList{std::get<omp::ObjectList>(reductionClause->t)};
if (const auto *reductionOp =
- std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+ std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
- std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+ std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
reductionOp->u)};
switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Add:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
+ case omp::clause::DefinedOperator::IntrinsicOperator::AND:
+ case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::OR:
+ case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
break;
default:
continue;
}
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- mlir::Type reductionType =
- reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
- if (!reductionType.isa<fir::LogicalType>()) {
- if (!reductionType.isIntOrIndexOrFloat())
- continue;
- }
- for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
- reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- if (reductionType.isa<fir::LogicalType>()) {
- mlir::Operation *reductionOp = findReductionChain(loadVal);
- fir::ConvertOp convertOp =
- getConvertFromReductionOp(reductionOp, loadVal);
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal, &convertOp);
- removeStoreOp(reductionOp, reductionVal);
- } else if (mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal)) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
+ for (const omp::Object &object : objectList) {
+ if (const Fortran::semantics::Symbol *symbol = object.sym) {
+ mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+ reductionVal = declOp.getBase();
+ mlir::Type reductionType =
+ reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
+ if (!reductionType.isa<fir::LogicalType>()) {
+ if (!reductionType.isIntOrIndexOrFloat())
+ continue;
+ }
+ for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
+ if (auto loadOp =
+ mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+ mlir::Value loadVal = loadOp.getRes();
+ if (reductionType.isa<fir::LogicalType>()) {
+ mlir::Operation *reductionOp = findReductionChain(loadVal);
+ fir::ConvertOp convertOp =
+ getConvertFromReductionOp(reductionOp, loadVal);
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal, &convertOp);
+ removeStoreOp(reductionOp, reductionVal);
+ } else if (mlir::Operation *reductionOp =
+ findReductionChain(loadVal, &reductionVal)) {
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal);
}
}
}
}
}
} else if (const auto *reductionIntrinsic =
- std::get_if<Fortran::parser::ProcedureDesignator>(
+ std::get_if<omp::clause::ProcedureDesignator>(
&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic))
continue;
ReductionProcessor::ReductionIdentifier redId =
ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- for (const mlir::OpOperand &reductionValUse :
- reductionVal.getUses()) {
- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
- reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- // Max is lowered as a compare -> select.
- // Match the pattern here.
- mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal);
- if (reductionOp == nullptr)
- continue;
-
- if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
- redId == ReductionProcessor::ReductionIdentifier::MIN) {
- assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
- "Selection Op not found in reduction intrinsic");
- mlir::Operation *compareOp =
- getCompareFromReductionOp(reductionOp, loadVal);
- updateReduction(compareOp, firOpBuilder, loadVal,
- reductionVal);
- }
- if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
- redId == ReductionProcessor::ReductionIdentifier::IEOR ||
- redId == ReductionProcessor::ReductionIdentifier::IAND) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
+ for (const omp::Object &object : objectList) {
+ if (const Fortran::semantics::Symbol *symbol = object.sym) {
+ mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+ reductionVal = declOp.getBase();
+ for (const mlir::OpOperand &reductionValUse :
+ reductionVal.getUses()) {
+ if (auto loadOp =
+ mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+ mlir::Value loadVal = loadOp.getRes();
+ // Max is lowered as a compare -> select.
+ // Match the pattern here.
+ mlir::Operation *reductionOp =
+ findReductionChain(loadVal, &reductionVal);
+ if (reductionOp == nullptr)
+ continue;
+
+ if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
+ redId == ReductionProcessor::ReductionIdentifier::MIN) {
+ assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+ "Selection Op not found in reduction intrinsic");
+ mlir::Operation *compareOp =
+ getCompareFromReductionOp(reductionOp, loadVal);
+ updateReduction(compareOp, firOpBuilder, loadVal,
+ reductionVal);
+ }
+ if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
+ redId == ReductionProcessor::ReductionIdentifier::IEOR ||
+ redId == ReductionProcessor::ReductionIdentifier::IAND) {
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal);
}
}
}
>From 23eeb2b54d51047c9d2d2d7bb44e2612e79dcfdd Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 8 Feb 2024 08:33:40 -0600
Subject: [PATCH 4/6] [flang][Lower] Convert OMP Map and related functions to
evaluate::Expr
The related functions are `gatherDataOperandAddrAndBounds` and
`genBoundsOps`. The former is also used by OpenACC lowering, which was
updated to pass evaluate::Expr instead of parser objects.
The difference in the test case comes from unfolded conversions
of index expressions, which are explicitly of type integer(kind=8).
Delete now unused `findRepeatableClause2` and `findClause2`.
Add an `AsGenericExpr` overload that takes a std::optional. Since
`AsGenericExpr` already returns an optional Expr, also accepting an optional
input reduces the number of checks needed when handling the many optional
values produced by the evaluate library.
---
flang/include/flang/Evaluate/tools.h | 8 +
flang/lib/Lower/DirectivesCommon.h | 389 ++++++++++++++++-----------
flang/lib/Lower/OpenACC.cpp | 54 ++--
flang/lib/Lower/OpenMP.cpp | 105 +++-----
4 files changed, 311 insertions(+), 245 deletions(-)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index e9999974944e88..d5713cfe420a2e 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -148,6 +148,14 @@ inline Expr<SomeType> AsGenericExpr(Expr<SomeType> &&x) { return std::move(x); }
std::optional<Expr<SomeType>> AsGenericExpr(DataRef &&);
std::optional<Expr<SomeType>> AsGenericExpr(const Symbol &);
+// Propagate std::optional from input to output.
+template <typename A>
+std::optional<Expr<SomeType>> AsGenericExpr(std::optional<A> &&x) {
+ if (!x)
+ return std::nullopt;
+ return AsGenericExpr(std::move(*x));
+}
+
template <typename A>
common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(
A &&x) {
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 8d560db34e05bf..2fa90572bc63eb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -808,6 +808,75 @@ genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
return bounds;
}
+namespace detail {
+template <typename T> //
+static T &&AsRvalueRef(T &&t) {
+ return std::move(t);
+}
+template <typename T> //
+static T AsRvalueRef(T &t) {
+ return t;
+}
+template <typename T> //
+static T AsRvalueRef(const T &t) {
+ return t;
+}
+
+// Helper class for stripping enclosing parentheses and a conversion that
+// preserves type category. This is used for triplet elements, which are
+// always of type integer(kind=8). The lower/upper bounds are converted to
+// an "index" type, which is 64-bit, so the explicit conversion to kind=8
+// (if present) is not needed. When it's present, though, it causes generated
+// names to contain "int(..., kind=8)".
+struct PeelConvert {
+ template <Fortran::common::TypeCategory Category, int Kind>
+ static Fortran::semantics::MaybeExpr visit_with_category(
+ const Fortran::evaluate::Expr<Fortran::evaluate::Type<Category, Kind>>
+ &expr) {
+ return std::visit(
+ [](auto &&s) { return visit_with_category<Category, Kind>(s); },
+ expr.u);
+ }
+ template <Fortran::common::TypeCategory Category, int Kind>
+ static Fortran::semantics::MaybeExpr visit_with_category(
+ const Fortran::evaluate::Convert<Fortran::evaluate::Type<Category, Kind>,
+ Category> &expr) {
+ return AsGenericExpr(AsRvalueRef(expr.left()));
+ }
+ template <Fortran::common::TypeCategory Category, int Kind, typename T>
+ static Fortran::semantics::MaybeExpr visit_with_category(const T &) {
+ return std::nullopt; //
+ }
+ template <Fortran::common::TypeCategory Category, typename T>
+ static Fortran::semantics::MaybeExpr visit_with_category(const T &) {
+ return std::nullopt; //
+ }
+
+ template <Fortran::common::TypeCategory Category>
+ static Fortran::semantics::MaybeExpr
+ visit(const Fortran::evaluate::Expr<Fortran::evaluate::SomeKind<Category>>
+ &expr) {
+ return std::visit([](auto &&s) { return visit_with_category<Category>(s); },
+ expr.u);
+ }
+ static Fortran::semantics::MaybeExpr
+ visit(const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr) {
+ return std::visit([](auto &&s) { return visit(s); }, expr.u);
+ }
+ template <typename T> //
+ static Fortran::semantics::MaybeExpr visit(const T &) {
+ return std::nullopt;
+ }
+};
+
+static Fortran::semantics::SomeExpr
+peelOuterConvert(Fortran::semantics::SomeExpr &expr) {
+ if (auto peeled = PeelConvert::visit(expr))
+ return *peeled;
+ return expr;
+}
+} // namespace detail
+
/// Generate bounds operations for an array section when subscripts are
/// provided.
template <typename BoundsOp, typename BoundsType>
@@ -815,7 +884,7 @@ llvm::SmallVector<mlir::Value>
genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext &stmtCtx,
- const std::list<Fortran::parser::SectionSubscript> &subscripts,
+ const std::vector<Fortran::evaluate::Subscript> &subscripts,
std::stringstream &asFortran, fir::ExtendedValue &dataExv,
bool dataExvIsAssumedSize, AddrAndBoundsInfo &info,
bool treatIndexAsSection = false) {
@@ -828,8 +897,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
const int dataExvRank = static_cast<int>(dataExv.rank());
for (const auto &subscript : subscripts) {
- const auto *triplet{
- std::get_if<Fortran::parser::SubscriptTriplet>(&subscript.u)};
+ const auto *triplet{std::get_if<Fortran::evaluate::Triplet>(&subscript.u)};
if (triplet || treatIndexAsSection) {
if (dimension != 0)
asFortran << ',';
@@ -868,13 +936,18 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
strideInBytes = true;
}
- const Fortran::lower::SomeExpr *lower{nullptr};
+ Fortran::semantics::MaybeExpr lower;
if (triplet) {
- if (const auto &tripletLb{std::get<0>(triplet->t)})
- lower = Fortran::semantics::GetExpr(*tripletLb);
+ if ((lower = Fortran::evaluate::AsGenericExpr(triplet->lower())))
+ lower = detail::peelOuterConvert(*lower);
} else {
- const auto &index{std::get<Fortran::parser::IntExpr>(subscript.u)};
- lower = Fortran::semantics::GetExpr(index);
+ // Case of IndirectSubscriptIntegerExpr
+ using IndirectSubscriptIntegerExpr =
+ Fortran::evaluate::IndirectSubscriptIntegerExpr;
+ using SubscriptInteger = Fortran::evaluate::SubscriptInteger;
+ Fortran::evaluate::Expr<SubscriptInteger> oneInt =
+ std::get<IndirectSubscriptIntegerExpr>(subscript.u).value();
+ lower = Fortran::evaluate::AsGenericExpr(std::move(oneInt));
if (lower->Rank() > 0) {
mlir::emitError(
loc, "vector subscript cannot be used for an array section");
@@ -912,10 +985,12 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
extent = one;
} else {
asFortran << ':';
- const auto &upper{std::get<1>(triplet->t)};
+ Fortran::semantics::MaybeExpr upper =
+ Fortran::evaluate::AsGenericExpr(triplet->upper());
if (upper) {
- uval = Fortran::semantics::GetIntValue(upper);
+ upper = detail::peelOuterConvert(*upper);
+ uval = Fortran::evaluate::ToInt64(*upper);
if (uval) {
if (defaultLb) {
ubound = builder.createIntegerConstant(loc, idxTy, *uval - 1);
@@ -925,22 +1000,21 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
}
asFortran << *uval;
} else {
- const Fortran::lower::SomeExpr *uexpr =
- Fortran::semantics::GetExpr(*upper);
mlir::Value ub =
- fir::getBase(converter.genExprValue(loc, *uexpr, stmtCtx));
+ fir::getBase(converter.genExprValue(loc, *upper, stmtCtx));
ub = builder.createConvert(loc, baseLb.getType(), ub);
ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb);
- asFortran << uexpr->AsFortran();
+ asFortran << upper->AsFortran();
}
}
if (lower && upper) {
if (lval && uval && *uval < *lval) {
mlir::emitError(loc, "zero sized array section");
break;
- } else if (std::get<2>(triplet->t)) {
- const auto &strideExpr{std::get<2>(triplet->t)};
- if (strideExpr) {
+ } else {
+ // Stride is mandatory in evaluate::Triplet. Make sure it's 1.
+ auto val = Fortran::evaluate::ToInt64(triplet->GetStride());
+ if (!val || *val != 1) {
mlir::emitError(loc, "stride cannot be specified on "
"an array section");
break;
@@ -993,150 +1067,157 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
return bounds;
}
-template <typename ObjectType, typename BoundsOp, typename BoundsType>
+namespace detail {
+template <typename Ref, typename Expr> //
+std::optional<Ref> getRef(Expr &&expr) {
+ if constexpr (std::is_same_v<llvm::remove_cvref_t<Expr>,
+ Fortran::evaluate::DataRef>) {
+ if (auto *ref = std::get_if<Ref>(&expr.u))
+ return *ref;
+ return std::nullopt;
+ } else {
+ auto maybeRef = Fortran::evaluate::ExtractDataRef(expr);
+ if (!maybeRef || !std::holds_alternative<Ref>(maybeRef->u))
+ return std::nullopt;
+ return std::get<Ref>(maybeRef->u);
+ }
+}
+} // namespace detail
+
+template <typename BoundsOp, typename BoundsType>
AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::StatementContext &stmtCtx, const ObjectType &object,
+ semantics::SemanticsContext &semaCtx,
+ Fortran::lower::StatementContext &stmtCtx,
+ Fortran::semantics::SymbolRef symbol,
+ const Fortran::semantics::MaybeExpr &maybeDesignator,
mlir::Location operandLocation, std::stringstream &asFortran,
llvm::SmallVector<mlir::Value> &bounds, bool treatIndexAsSection = false) {
+ using namespace Fortran;
+
AddrAndBoundsInfo info;
- std::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::Designator &designator) {
- if (auto expr{Fortran::semantics::AnalyzeExpr(semanticsContext,
- designator)}) {
- if (((*expr).Rank() > 0 || treatIndexAsSection) &&
- Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
- designator)) {
- const auto *arrayElement =
- Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
- designator);
- const auto *dataRef =
- std::get_if<Fortran::parser::DataRef>(&designator.u);
- fir::ExtendedValue dataExv;
- bool dataExvIsAssumedSize = false;
- if (Fortran::parser::Unwrap<
- Fortran::parser::StructureComponent>(
- arrayElement->base)) {
- auto exprBase = Fortran::semantics::AnalyzeExpr(
- semanticsContext, arrayElement->base);
- dataExv = converter.genExprAddr(operandLocation, *exprBase,
- stmtCtx);
- info.addr = fir::getBase(dataExv);
- info.rawInput = info.addr;
- asFortran << (*exprBase).AsFortran();
- } else {
- const Fortran::parser::Name &name =
- Fortran::parser::GetLastName(*dataRef);
- dataExvIsAssumedSize = Fortran::semantics::IsAssumedSizeArray(
- name.symbol->GetUltimate());
- info = getDataOperandBaseAddr(converter, builder,
- *name.symbol, operandLocation);
- dataExv = converter.getSymbolExtendedValue(*name.symbol);
- asFortran << name.ToString();
- }
-
- if (!arrayElement->subscripts.empty()) {
- asFortran << '(';
- bounds = genBoundsOps<BoundsOp, BoundsType>(
- builder, operandLocation, converter, stmtCtx,
- arrayElement->subscripts, asFortran, dataExv,
- dataExvIsAssumedSize, info, treatIndexAsSection);
- }
- asFortran << ')';
- } else if (auto structComp = Fortran::parser::Unwrap<
- Fortran::parser::StructureComponent>(designator)) {
- fir::ExtendedValue compExv =
- converter.genExprAddr(operandLocation, *expr, stmtCtx);
- info.addr = fir::getBase(compExv);
- info.rawInput = info.addr;
- if (fir::unwrapRefType(info.addr.getType())
- .isa<fir::SequenceType>())
- bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
- builder, operandLocation, converter, compExv,
- /*isAssumedSize=*/false);
- asFortran << (*expr).AsFortran();
-
- bool isOptional = Fortran::semantics::IsOptional(
- *Fortran::parser::GetLastName(*structComp).symbol);
- if (isOptional)
- info.isPresent = builder.create<fir::IsPresentOp>(
- operandLocation, builder.getI1Type(), info.rawInput);
-
- if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(
- info.addr.getDefiningOp())) {
- if (fir::isAllocatableType(loadOp.getType()) ||
- fir::isPointerType(loadOp.getType()))
- info.addr = builder.create<fir::BoxAddrOp>(operandLocation,
- info.addr);
- info.rawInput = info.addr;
- }
-
- // If the component is an allocatable or pointer the result of
- // genExprAddr will be the result of a fir.box_addr operation or
- // a fir.box_addr has been inserted just before.
- // Retrieve the box so we handle it like other descriptor.
- if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
- info.addr.getDefiningOp())) {
- info.addr = boxAddrOp.getVal();
- info.rawInput = info.addr;
- bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
- builder, operandLocation, converter, compExv, info);
- }
- } else {
- if (Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
- designator)) {
- // Single array element.
- const auto *arrayElement =
- Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
- designator);
- (void)arrayElement;
- fir::ExtendedValue compExv =
- converter.genExprAddr(operandLocation, *expr, stmtCtx);
- info.addr = fir::getBase(compExv);
- info.rawInput = info.addr;
- asFortran << (*expr).AsFortran();
- } else if (const auto *dataRef{
- std::get_if<Fortran::parser::DataRef>(
- &designator.u)}) {
- // Scalar or full array.
- const Fortran::parser::Name &name =
- Fortran::parser::GetLastName(*dataRef);
- fir::ExtendedValue dataExv =
- converter.getSymbolExtendedValue(*name.symbol);
- info = getDataOperandBaseAddr(converter, builder,
- *name.symbol, operandLocation);
- if (fir::unwrapRefType(info.addr.getType())
- .isa<fir::BaseBoxType>()) {
- bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
- builder, operandLocation, converter, dataExv, info);
- }
- bool dataExvIsAssumedSize =
- Fortran::semantics::IsAssumedSizeArray(
- name.symbol->GetUltimate());
- if (fir::unwrapRefType(info.addr.getType())
- .isa<fir::SequenceType>())
- bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
- builder, operandLocation, converter, dataExv,
- dataExvIsAssumedSize);
- asFortran << name.ToString();
- } else { // Unsupported
- llvm::report_fatal_error(
- "Unsupported type of OpenACC operand");
- }
- }
- }
- },
- [&](const Fortran::parser::Name &name) {
- info = getDataOperandBaseAddr(converter, builder, *name.symbol,
- operandLocation);
- asFortran << name.ToString();
- }},
- object.u);
+
+ if (!maybeDesignator) {
+ info = getDataOperandBaseAddr(converter, builder, symbol, operandLocation);
+ asFortran << symbol->name().ToString();
+ return info;
+ }
+
+ semantics::SomeExpr designator = *maybeDesignator;
+
+ if ((designator.Rank() > 0 || treatIndexAsSection) &&
+ IsArrayElement(designator)) {
+ auto arrayRef = detail::getRef<evaluate::ArrayRef>(designator);
+ // This shouldn't fail after IsArrayElement(designator).
+ assert(arrayRef && "Expecting ArrayRef");
+
+ fir::ExtendedValue dataExv;
+ bool dataExvIsAssumedSize = false;
+
+ auto toMaybeExpr = [&](auto &&base) {
+ using BaseType = llvm::remove_cvref_t<decltype(base)>;
+ evaluate::ExpressionAnalyzer ea{semaCtx};
+
+ if constexpr (std::is_same_v<evaluate::NamedEntity, BaseType>) {
+ if (auto *ref = base.UnwrapSymbolRef())
+ return ea.Designate(evaluate::DataRef{*ref});
+ if (auto *ref = base.UnwrapComponent())
+ return ea.Designate(evaluate::DataRef{*ref});
+ llvm_unreachable("Unexpected NamedEntity");
+ } else {
+ static_assert(std::is_same_v<semantics::SymbolRef, BaseType>);
+ return ea.Designate(evaluate::DataRef{base});
+ }
+ };
+
+ auto arrayBase = toMaybeExpr(arrayRef->base());
+ assert(arrayBase);
+
+ if (detail::getRef<evaluate::Component>(*arrayBase)) {
+ dataExv = converter.genExprAddr(operandLocation, *arrayBase, stmtCtx);
+ info.addr = fir::getBase(dataExv);
+ info.rawInput = info.addr;
+ asFortran << arrayBase->AsFortran();
+ } else {
+ const semantics::Symbol &sym = arrayRef->GetLastSymbol();
+ dataExvIsAssumedSize =
+ Fortran::semantics::IsAssumedSizeArray(sym.GetUltimate());
+ info = getDataOperandBaseAddr(converter, builder, sym, operandLocation);
+ dataExv = converter.getSymbolExtendedValue(sym);
+ asFortran << sym.name().ToString();
+ }
+
+ if (!arrayRef->subscript().empty()) {
+ asFortran << '(';
+ bounds = genBoundsOps<BoundsOp, BoundsType>(
+ builder, operandLocation, converter, stmtCtx, arrayRef->subscript(),
+ asFortran, dataExv, dataExvIsAssumedSize, info, treatIndexAsSection);
+ }
+ asFortran << ')';
+ } else if (auto compRef = detail::getRef<evaluate::Component>(designator)) {
+ fir::ExtendedValue compExv =
+ converter.genExprAddr(operandLocation, designator, stmtCtx);
+ info.addr = fir::getBase(compExv);
+ info.rawInput = info.addr;
+ if (fir::unwrapRefType(info.addr.getType()).isa<fir::SequenceType>())
+ bounds = genBaseBoundsOps<BoundsOp, BoundsType>(builder, operandLocation,
+ converter, compExv,
+ /*isAssumedSize=*/false);
+ asFortran << designator.AsFortran();
+
+ if (semantics::IsOptional(compRef->GetLastSymbol())) {
+ info.isPresent = builder.create<fir::IsPresentOp>(
+ operandLocation, builder.getI1Type(), info.rawInput);
+ }
+
+ if (auto loadOp =
+ mlir::dyn_cast_or_null<fir::LoadOp>(info.addr.getDefiningOp())) {
+ if (fir::isAllocatableType(loadOp.getType()) ||
+ fir::isPointerType(loadOp.getType()))
+ info.addr = builder.create<fir::BoxAddrOp>(operandLocation, info.addr);
+ info.rawInput = info.addr;
+ }
+
+ // If the component is an allocatable or pointer the result of
+ // genExprAddr will be the result of a fir.box_addr operation or
+ // a fir.box_addr has been inserted just before.
+ // Retrieve the box so we handle it like other descriptor.
+ if (auto boxAddrOp =
+ mlir::dyn_cast_or_null<fir::BoxAddrOp>(info.addr.getDefiningOp())) {
+ info.addr = boxAddrOp.getVal();
+ info.rawInput = info.addr;
+ bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
+ builder, operandLocation, converter, compExv, info);
+ }
+ } else {
+ if (detail::getRef<evaluate::ArrayRef>(designator)) {
+ fir::ExtendedValue compExv =
+ converter.genExprAddr(operandLocation, designator, stmtCtx);
+ info.addr = fir::getBase(compExv);
+ info.rawInput = info.addr;
+ asFortran << designator.AsFortran();
+ } else if (auto symRef = detail::getRef<semantics::SymbolRef>(designator)) {
+ // Scalar or full array.
+ fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*symRef);
+ info =
+ getDataOperandBaseAddr(converter, builder, *symRef, operandLocation);
+ if (fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
+ bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
+ builder, operandLocation, converter, dataExv, info);
+ }
+ bool dataExvIsAssumedSize =
+ Fortran::semantics::IsAssumedSizeArray(symRef->get().GetUltimate());
+ if (fir::unwrapRefType(info.addr.getType()).isa<fir::SequenceType>())
+ bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
+ builder, operandLocation, converter, dataExv, dataExvIsAssumedSize);
+ asFortran << symRef->get().name().ToString();
+ } else { // Unsupported
+ llvm::report_fatal_error("Unsupported type of OpenACC operand");
+ }
+ }
+
return info;
}
-
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 6ae270f63f5cf4..a444682306ac20 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -269,6 +269,11 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
Fortran::parser::GetLastName(arrayElement->base);
return *name.symbol;
}
+ if (const auto *component =
+ Fortran::parser::Unwrap<Fortran::parser::StructureComponent>(
+ *designator)) {
+ return *component->component.symbol;
+ }
} else if (const auto *name =
std::get_if<Fortran::parser::Name>(&accObject.u)) {
return *name->symbol;
@@ -286,17 +291,20 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
mlir::acc::DataClause dataClause, bool structured,
bool implicit, bool setDeclareAttr = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
+ Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
+ Fortran::semantics::MaybeExpr designator =
+ std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
- mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
- stmtCtx, accObject, operandLocation,
- asFortran, bounds,
- /*treatIndexAsSection=*/true);
+ mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+ converter, builder, semanticsContext, stmtCtx, symbol, designator,
+ operandLocation, asFortran, bounds,
+ /*treatIndexAsSection=*/true);
// If the input value is optional and is not a descriptor, we use the
// rawInput directly.
@@ -321,16 +329,19 @@ static void genDeclareDataOperandOperations(
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
mlir::acc::DataClause dataClause, bool structured, bool implicit) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
+ Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
+ Fortran::semantics::MaybeExpr designator =
+ std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
- mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
- stmtCtx, accObject, operandLocation,
- asFortran, bounds);
+ mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+ converter, builder, semanticsContext, stmtCtx, symbol, designator,
+ operandLocation, asFortran, bounds);
EntryOp op = createDataEntryOp<EntryOp>(
builder, operandLocation, info.addr, asFortran, bounds, structured,
implicit, dataClause, info.addr.getType());
@@ -339,8 +350,7 @@ static void genDeclareDataOperandOperations(
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
modBuilder.setInsertionPointAfter(builder.getFunction());
- std::string prefix =
- converter.mangleName(getSymbolFromAccObject(accObject));
+ std::string prefix = converter.mangleName(symbol);
createDeclareAllocFuncWithArg<EntryOp>(
modBuilder, builder, operandLocation, info.addr.getType(), prefix,
asFortran, dataClause);
@@ -783,16 +793,19 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
llvm::SmallVector<mlir::Attribute> &privatizations) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
+ Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
+ Fortran::semantics::MaybeExpr designator =
+ std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
- mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
- stmtCtx, accObject, operandLocation,
- asFortran, bounds);
+ mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+ converter, builder, semanticsContext, stmtCtx, symbol, designator,
+ operandLocation, asFortran, bounds);
RecipeOp recipe;
mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
@@ -1361,16 +1374,19 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
const auto &op =
std::get<Fortran::parser::AccReductionOperator>(objectList.t);
mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
+ Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objects.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
+ Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
+ Fortran::semantics::MaybeExpr designator =
+ std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
- mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
- stmtCtx, accObject, operandLocation,
- asFortran, bounds);
+ mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+ converter, builder, semanticsContext, stmtCtx, symbol, designator,
+ operandLocation, asFortran, bounds);
mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 921649b868040c..ad38d0aeb16ad6 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1789,18 +1789,6 @@ class ClauseProcessor {
return end;
}
- /// Utility to find a clause within a range in the clause list.
- template <typename T>
- static ClauseIterator2 findClause2(ClauseIterator2 begin,
- ClauseIterator2 end) {
- for (ClauseIterator2 it = begin; it != end; ++it) {
- if (std::get_if<T>(&it->u))
- return it;
- }
-
- return end;
- }
-
/// Return the first instance of the given clause found in the clause list or
/// `nullptr` if not present. If more than one instance is expected, use
/// `findRepeatableClause` instead.
@@ -1836,26 +1824,6 @@ class ClauseProcessor {
return found;
}
- /// Call `callbackFn` for each occurrence of the given clause. Return `true`
- /// if at least one instance was found.
- template <typename T>
- bool findRepeatableClause2(
- std::function<void(const T *, const Fortran::parser::CharBlock &source)>
- callbackFn) const {
- bool found = false;
- ClauseIterator2 nextIt, endIt = clauses2.v.end();
- for (ClauseIterator2 it = clauses2.v.begin(); it != endIt; it = nextIt) {
- nextIt = findClause2<T>(it, endIt);
-
- if (nextIt != endIt) {
- callbackFn(&std::get<T>(nextIt->u), nextIt->source);
- found = true;
- ++nextIt;
- }
- }
- return found;
- }
-
/// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
template <typename T>
bool markClauseOccurrence(mlir::UnitAttr &result) const {
@@ -2958,65 +2926,61 @@ bool ClauseProcessor::processMap(
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- return findRepeatableClause2<ClauseTy::Map>(
- [&](const ClauseTy::Map *mapClause,
+ return findRepeatableClause<omp::clause::Map>(
+ [&](const omp::clause::Map &clause,
const Fortran::parser::CharBlock &source) {
+ using Map = omp::clause::Map;
mlir::Location clauseLocation = converter.genLocation(source);
- const auto &oMapType =
- std::get<std::optional<Fortran::parser::OmpMapType>>(
- mapClause->v.t);
+ const auto &oMapType = std::get<std::optional<Map::MapType>>(clause.t);
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
// If the map type is specified, then process it else Tofrom is the
// default.
if (oMapType) {
- const Fortran::parser::OmpMapType::Type &mapType =
- std::get<Fortran::parser::OmpMapType::Type>(oMapType->t);
+ const Map::MapType::Type &mapType =
+ std::get<Map::MapType::Type>(oMapType->t);
switch (mapType) {
- case Fortran::parser::OmpMapType::Type::To:
+ case Map::MapType::Type::To:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
break;
- case Fortran::parser::OmpMapType::Type::From:
+ case Map::MapType::Type::From:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
break;
- case Fortran::parser::OmpMapType::Type::Tofrom:
+ case Map::MapType::Type::Tofrom:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
break;
- case Fortran::parser::OmpMapType::Type::Alloc:
- case Fortran::parser::OmpMapType::Type::Release:
+ case Map::MapType::Type::Alloc:
+ case Map::MapType::Type::Release:
// alloc and release is the default map_type for the Target Data
// Ops, i.e. if no bits for map_type is supplied then alloc/release
// is implicitly assumed based on the target directive. Default
// value for Target Data and Enter Data is alloc and for Exit Data
// it is release.
break;
- case Fortran::parser::OmpMapType::Type::Delete:
+ case Map::MapType::Type::Delete:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
}
- if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
- oMapType->t))
+ if (std::get<std::optional<Map::MapType::Always>>(oMapType->t))
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
} else {
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
}
- for (const Fortran::parser::OmpObject &ompObject :
- std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) {
+ for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
- mlir::omp::DataBoundsType>(
- converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
- clauseLocation, asFortran, bounds, treatIndexAsSection);
+ mlir::omp::DataBoundsOp, mlir::omp::DataBoundsType>(
+ converter, firOpBuilder, semaCtx, stmtCtx, *object.sym,
+ object.dsg, clauseLocation, asFortran, bounds,
+ treatIndexAsSection);
- auto origSymbol =
- converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
+ auto origSymbol = converter.getSymbolAddress(*object.sym);
mlir::Value symAddr = info.addr;
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
symAddr = origSymbol;
@@ -3039,7 +3003,7 @@ bool ClauseProcessor::processMap(
mapSymLocs->push_back(symAddr.getLoc());
if (mapSymbols)
- mapSymbols->push_back(getOmpObjectSymbol(ompObject));
+ mapSymbols->push_back(object.sym);
}
});
}
@@ -3120,32 +3084,31 @@ template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
- return findRepeatableClause2<T>(
- [&](const T *motionClause, const Fortran::parser::CharBlock &source) {
+ return findRepeatableClause<T>(
+ [&](const T &clause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> ||
- std::is_same_v<T, ClauseProcessor::ClauseTy::From>);
+ static_assert(std::is_same_v<T, omp::clause::To> ||
+ std::is_same_v<T, omp::clause::From>);
// TODO Support motion modifiers: present, mapper, iterator.
constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
- std::is_same_v<T, ClauseProcessor::ClauseTy::To>
+ std::is_same_v<T, omp::clause::To>
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
- for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
+ for (const omp::Object &object : clause.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
- mlir::omp::DataBoundsType>(
- converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
- clauseLocation, asFortran, bounds, treatIndexAsSection);
+ mlir::omp::DataBoundsOp, mlir::omp::DataBoundsType>(
+ converter, firOpBuilder, semaCtx, stmtCtx, *object.sym,
+ object.dsg, clauseLocation, asFortran, bounds,
+ treatIndexAsSection);
- auto origSymbol =
- converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
+ auto origSymbol = converter.getSymbolAddress(*object.sym);
mlir::Value symAddr = info.addr;
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
symAddr = origSymbol;
@@ -3902,10 +3865,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
cp.processNowait(nowaitAttr);
if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
- cp.processMotionClauses<Fortran::parser::OmpClause::To>(stmtCtx,
- mapOperands);
- cp.processMotionClauses<Fortran::parser::OmpClause::From>(stmtCtx,
- mapOperands);
+ cp.processMotionClauses<omp::clause::To>(stmtCtx, mapOperands);
+ cp.processMotionClauses<omp::clause::From>(stmtCtx, mapOperands);
} else {
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
>From 8b774cd5e89ce935d43a6fa571bee53435dd71fc Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Fri, 9 Feb 2024 15:03:54 -0600
Subject: [PATCH 5/6] [flang][OpenMP] Convert processTODO and remove unused
objects
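Convert `ClauseProcessor::processTODO` to iterate over the `omp::Clause` list
(reporting unhandled clauses by their `llvm::omp::Clause` id), and remove the
now-unused `ClauseIterator2` alias and `clauses2` member from ClauseProcessor.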
---
flang/lib/Lower/OpenMP.cpp | 75 ++++++++++++++------------------------
1 file changed, 28 insertions(+), 47 deletions(-)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index ad38d0aeb16ad6..1aafa07c3b71f0 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1669,13 +1669,11 @@ void DataSharingProcessor::defaultPrivatize() {
/// methods that relate to clauses that can impact the lowering of that
/// construct.
class ClauseProcessor {
- using ClauseTy = Fortran::parser::OmpClause;
-
public:
ClauseProcessor(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
const Fortran::parser::OmpClauseList &clauses)
- : converter(converter), semaCtx(semaCtx), clauses2(clauses),
+ : converter(converter), semaCtx(semaCtx),
clauses(omp::makeList(clauses, semaCtx)) {}
// 'Unique' clauses: They can appear at most once in the clause list.
@@ -1776,7 +1774,6 @@ class ClauseProcessor {
private:
using ClauseIterator = omp::List<omp::Clause>::const_iterator;
- using ClauseIterator2 = std::list<ClauseTy>::const_iterator;
/// Utility to find a clause within a range in the clause list.
template <typename T>
@@ -1836,7 +1833,6 @@ class ClauseProcessor {
Fortran::lower::AbstractConverter &converter;
Fortran::semantics::SemanticsContext &semaCtx;
- const Fortran::parser::OmpClauseList &clauses2;
omp::List<omp::Clause> clauses;
};
@@ -3132,19 +3128,17 @@ bool ClauseProcessor::processMotionClauses(
template <typename... Ts>
void ClauseProcessor::processTODO(mlir::Location currentLocation,
llvm::omp::Directive directive) const {
- auto checkUnhandledClause = [&](const auto *x) {
+ auto checkUnhandledClause = [&](llvm::omp::Clause id, const auto *x) {
if (!x)
return;
TODO(currentLocation,
- "Unhandled clause " +
- llvm::StringRef(Fortran::parser::ParseTreeDumper::GetNodeName(*x))
- .upper() +
+ "Unhandled clause " + llvm::omp::getOpenMPClauseName(id).upper() +
" in " + llvm::omp::getOpenMPDirectiveName(directive).upper() +
" construct");
};
- for (ClauseIterator2 it = clauses2.v.begin(); it != clauses2.v.end(); ++it)
- (checkUnhandledClause(std::get_if<Ts>(&it->u)), ...);
+ for (ClauseIterator it = clauses.begin(); it != clauses.end(); ++it)
+ (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...);
}
//===----------------------------------------------------------------------===//
@@ -3725,8 +3719,8 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
ClauseProcessor cp(converter, semaCtx, beginClauseList);
cp.processAllocate(allocatorOperands, allocateOperands);
- cp.processTODO<Fortran::parser::OmpClause::Copyprivate>(
- currentLocation, llvm::omp::Directive::OMPD_single);
+ cp.processTODO<omp::clause::Copyprivate>(currentLocation,
+ llvm::omp::Directive::OMPD_single);
ClauseProcessor(converter, semaCtx, endClauseList).processNowait(nowaitAttr);
@@ -3760,10 +3754,9 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
cp.processMergeable(mergeableAttr);
cp.processPriority(stmtCtx, priorityClauseOperand);
cp.processDepend(dependTypeOperands, dependOperands);
- cp.processTODO<Fortran::parser::OmpClause::InReduction,
- Fortran::parser::OmpClause::Detach,
- Fortran::parser::OmpClause::Affinity>(
- currentLocation, llvm::omp::Directive::OMPD_task);
+ cp.processTODO<omp::clause::InReduction, omp::clause::Detach,
+ omp::clause::Affinity>(currentLocation,
+ llvm::omp::Directive::OMPD_task);
return genOpWithBody<mlir::omp::TaskOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
@@ -3788,7 +3781,7 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processAllocate(allocatorOperands, allocateOperands);
- cp.processTODO<Fortran::parser::OmpClause::TaskReduction>(
+ cp.processTODO<omp::clause::TaskReduction>(
currentLocation, llvm::omp::Directive::OMPD_taskgroup);
return genOpWithBody<mlir::omp::TaskGroupOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
@@ -3872,8 +3865,7 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
}
- cp.processTODO<Fortran::parser::OmpClause::Depend>(currentLocation,
- directive);
+ cp.processTODO<omp::clause::Depend>(currentLocation, directive);
return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
deviceOperand, nullptr, mlir::ValueRange(),
@@ -4056,16 +4048,11 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
cp.processNowait(nowaitAttr);
cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
&mapSymLocs, &mapSymbols);
- cp.processTODO<Fortran::parser::OmpClause::Private,
- Fortran::parser::OmpClause::Depend,
- Fortran::parser::OmpClause::Firstprivate,
- Fortran::parser::OmpClause::IsDevicePtr,
- Fortran::parser::OmpClause::HasDeviceAddr,
- Fortran::parser::OmpClause::Reduction,
- Fortran::parser::OmpClause::InReduction,
- Fortran::parser::OmpClause::Allocate,
- Fortran::parser::OmpClause::UsesAllocators,
- Fortran::parser::OmpClause::Defaultmap>(
+ cp.processTODO<omp::clause::Private, omp::clause::Depend,
+ omp::clause::Firstprivate, omp::clause::IsDevicePtr,
+ omp::clause::HasDeviceAddr, omp::clause::Reduction,
+ omp::clause::InReduction, omp::clause::Allocate,
+ omp::clause::UsesAllocators, omp::clause::Defaultmap>(
currentLocation, llvm::omp::Directive::OMPD_target);
// 5.8.1 Implicit Data-Mapping Attribute Rules
@@ -4168,8 +4155,8 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
cp.processDefault();
cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
cp.processThreadLimit(stmtCtx, threadLimitClauseOperand);
- cp.processTODO<Fortran::parser::OmpClause::Reduction>(
- currentLocation, llvm::omp::Directive::OMPD_teams);
+ cp.processTODO<omp::clause::Reduction>(currentLocation,
+ llvm::omp::Directive::OMPD_teams);
return genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
@@ -4221,7 +4208,7 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
cp.processEnter(symbolAndClause);
cp.processLink(symbolAndClause);
cp.processDeviceType(deviceType);
- cp.processTODO<Fortran::parser::OmpClause::Indirect>(
+ cp.processTODO<omp::clause::Indirect>(
converter.getCurrentLocation(),
llvm::omp::Directive::OMPD_declare_target);
}
@@ -4280,8 +4267,7 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
break;
case llvm::omp::Directive::OMPD_taskwait:
ClauseProcessor(converter, semaCtx, opClauseList)
- .processTODO<Fortran::parser::OmpClause::Depend,
- Fortran::parser::OmpClause::Nowait>(
+ .processTODO<omp::clause::Depend, omp::clause::Nowait>(
currentLocation, llvm::omp::Directive::OMPD_taskwait);
firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation);
break;
@@ -4483,11 +4469,9 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
cp.processIf(omp::clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
cp.processSimdlen(simdlenClauseOperand);
cp.processSafelen(safelenClauseOperand);
- cp.processTODO<Fortran::parser::OmpClause::Aligned,
- Fortran::parser::OmpClause::Allocate,
- Fortran::parser::OmpClause::Linear,
- Fortran::parser::OmpClause::Nontemporal,
- Fortran::parser::OmpClause::Order>(loc, ompDirective);
+ cp.processTODO<omp::clause::Aligned, omp::clause::Allocate,
+ omp::clause::Linear, omp::clause::Nontemporal,
+ omp::clause::Order>(loc, ompDirective);
convertLoopBounds(converter, loc, lowerBound, upperBound, step,
loopVarTypeSize);
@@ -4544,8 +4528,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
cp.processReduction(loc, reductionVars, reductionDeclSymbols,
&reductionSymbols);
- cp.processTODO<Fortran::parser::OmpClause::Linear,
- Fortran::parser::OmpClause::Order>(loc, ompDirective);
+ cp.processTODO<omp::clause::Linear, omp::clause::Order>(loc, ompDirective);
convertLoopBounds(converter, loc, lowerBound, upperBound, step,
loopVarTypeSize);
@@ -4612,11 +4595,9 @@ static void createSimdWsLoop(
const Fortran::parser::OmpClauseList &beginClauseList,
const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
ClauseProcessor cp(converter, semaCtx, beginClauseList);
- cp.processTODO<
- Fortran::parser::OmpClause::Aligned, Fortran::parser::OmpClause::Allocate,
- Fortran::parser::OmpClause::Linear, Fortran::parser::OmpClause::Safelen,
- Fortran::parser::OmpClause::Simdlen, Fortran::parser::OmpClause::Order>(
- loc, ompDirective);
+ cp.processTODO<omp::clause::Aligned, omp::clause::Allocate,
+ omp::clause::Linear, omp::clause::Safelen,
+ omp::clause::Simdlen, omp::clause::Order>(loc, ompDirective);
// TODO: Add support for vectorization - add vectorization hints inside loop
// body.
// OpenMP standard does not specify the length of vector instructions.
>From ac907eeb19955d22a7e8a06d69930e4b2996690e Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Sat, 10 Feb 2024 08:50:48 -0600
Subject: [PATCH 6/6] [flang][OpenMP] Convert DataSharingProcessor to
omp::Clause
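Make DataSharingProcessor take a SemanticsContext and operate on the
omp::List<omp::Clause> built by omp::makeList, instead of on the
parser::OmpClauseList directly.
---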
flang/lib/Lower/OpenMP.cpp | 303 ++++++++++++++++++-------------------
1 file changed, 149 insertions(+), 154 deletions(-)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 1aafa07c3b71f0..a179bb178c8d75 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1278,14 +1278,15 @@ class DataSharingProcessor {
llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInNestedRegions;
llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInParentRegions;
Fortran::lower::AbstractConverter &converter;
+ Fortran::semantics::SemanticsContext &semaCtx;
fir::FirOpBuilder &firOpBuilder;
- const Fortran::parser::OmpClauseList &opClauseList;
+ omp::List<omp::Clause> clauses;
Fortran::lower::pft::Evaluation &eval;
bool needBarrier();
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
void collectOmpObjectListSymbol(
- const Fortran::parser::OmpObjectList &ompObjectList,
+ const omp::ObjectList &objects,
llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet);
void collectSymbolsForPrivatization();
void insertBarrier();
@@ -1302,11 +1303,12 @@ class DataSharingProcessor {
public:
DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
const Fortran::parser::OmpClauseList &opClauseList,
Fortran::lower::pft::Evaluation &eval)
- : hasLastPrivateOp(false), converter(converter),
- firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
- eval(eval) {}
+ : hasLastPrivateOp(false), converter(converter), semaCtx(semaCtx),
+ firOpBuilder(converter.getFirOpBuilder()),
+ clauses(omp::makeList(opClauseList, semaCtx)), eval(eval) {}
// Privatisation is split into two steps.
// Step1 performs cloning of all privatisation clauses and copying for
// firstprivates. Step1 is performed at the place where process/processStep1
@@ -1384,30 +1386,28 @@ void DataSharingProcessor::copyLastPrivateSymbol(
}
void DataSharingProcessor::collectOmpObjectListSymbol(
- const Fortran::parser::OmpObjectList &ompObjectList,
+ const omp::ObjectList &objects,
llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet) {
- for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+ for (const omp::Object &object : objects) {
+ Fortran::semantics::Symbol *sym = object.sym;
symbolSet.insert(sym);
}
}
void DataSharingProcessor::collectSymbolsForPrivatization() {
bool hasCollapse = false;
- for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+ for (const omp::Clause &clause : clauses) {
if (const auto &privateClause =
- std::get_if<Fortran::parser::OmpClause::Private>(&clause.u)) {
+ std::get_if<omp::clause::Private>(&clause.u)) {
collectOmpObjectListSymbol(privateClause->v, privatizedSymbols);
} else if (const auto &firstPrivateClause =
- std::get_if<Fortran::parser::OmpClause::Firstprivate>(
- &clause.u)) {
+ std::get_if<omp::clause::Firstprivate>(&clause.u)) {
collectOmpObjectListSymbol(firstPrivateClause->v, privatizedSymbols);
} else if (const auto &lastPrivateClause =
- std::get_if<Fortran::parser::OmpClause::Lastprivate>(
- &clause.u)) {
+ std::get_if<omp::clause::Lastprivate>(&clause.u)) {
collectOmpObjectListSymbol(lastPrivateClause->v, privatizedSymbols);
hasLastPrivateOp = true;
- } else if (std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
+ } else if (std::get_if<omp::clause::Collapse>(&clause.u)) {
hasCollapse = true;
}
}
@@ -1440,138 +1440,135 @@ void DataSharingProcessor::insertBarrier() {
void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
bool cmpCreated = false;
mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint();
- for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
- if (std::get_if<Fortran::parser::OmpClause::Lastprivate>(&clause.u)) {
- // TODO: Add lastprivate support for simd construct
- if (mlir::isa<mlir::omp::SectionOp>(op)) {
- if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) {
- // For `omp.sections`, lastprivatized variables occur in
- // lexically final `omp.section` operation. The following FIR
- // shall be generated for the same:
- //
- // omp.sections lastprivate(...) {
- // omp.section {...}
- // omp.section {...}
- // omp.section {
- // fir.allocate for `private`/`firstprivate`
- // <More operations here>
- // fir.if %true {
- // ^%lpv_update_blk
- // }
- // }
- // }
- //
- // To keep code consistency while handling privatization
- // through this control flow, add a `fir.if` operation
- // that always evaluates to true, in order to create
- // a dedicated sub-region in `omp.section` where
- // lastprivate FIR can reside. Later canonicalizations
- // will optimize away this operation.
- if (!eval.lowerAsUnstructured()) {
- auto ifOp = firOpBuilder.create<fir::IfOp>(
- op->getLoc(),
- firOpBuilder.createIntegerConstant(
- op->getLoc(), firOpBuilder.getIntegerType(1), 0x1),
- /*else*/ false);
- firOpBuilder.setInsertionPointToStart(
- &ifOp.getThenRegion().front());
-
- const Fortran::parser::OpenMPConstruct *parentOmpConstruct =
- eval.parentConstruct->getIf<Fortran::parser::OpenMPConstruct>();
- assert(parentOmpConstruct &&
- "Expected a valid enclosing OpenMP construct");
- const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct =
- std::get_if<Fortran::parser::OpenMPSectionsConstruct>(
- &parentOmpConstruct->u);
- assert(sectionsConstruct &&
- "Expected an enclosing omp.sections construct");
- const Fortran::parser::OmpClauseList &sectionsEndClauseList =
- std::get<Fortran::parser::OmpClauseList>(
- std::get<Fortran::parser::OmpEndSectionsDirective>(
- sectionsConstruct->t)
- .t);
- for (const Fortran::parser::OmpClause &otherClause :
- sectionsEndClauseList.v)
- if (std::get_if<Fortran::parser::OmpClause::Nowait>(
- &otherClause.u))
- // Emit implicit barrier to synchronize threads and avoid data
- // races on post-update of lastprivate variables when `nowait`
- // clause is present.
- firOpBuilder.create<mlir::omp::BarrierOp>(
- converter.getCurrentLocation());
- firOpBuilder.setInsertionPointToStart(
- &ifOp.getThenRegion().front());
- lastPrivIP = firOpBuilder.saveInsertionPoint();
- firOpBuilder.setInsertionPoint(ifOp);
- insPt = firOpBuilder.saveInsertionPoint();
- } else {
- // Lastprivate operation is inserted at the end
- // of the lexically last section in the sections
- // construct
- mlir::OpBuilder::InsertPoint unstructuredSectionsIP =
- firOpBuilder.saveInsertionPoint();
- mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
- firOpBuilder.setInsertionPoint(lastOper);
- lastPrivIP = firOpBuilder.saveInsertionPoint();
- firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP);
- }
- }
- } else if (mlir::isa<mlir::omp::WsLoopOp>(op)) {
- // Update the original variable just before exiting the worksharing
- // loop. Conversion as follows:
+ for (const omp::Clause &clause : clauses) {
+ if (clause.id != llvm::omp::OMPC_lastprivate)
+ continue;
+ // TODO: Add lastprivate support for simd construct
+ if (mlir::isa<mlir::omp::SectionOp>(op)) {
+ if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) {
+ // For `omp.sections`, lastprivatized variables occur in
+ // lexically final `omp.section` operation. The following FIR
+ // shall be generated for the same:
//
- // omp.wsloop {
- // omp.wsloop { ...
- // ... store
- // store ===> %v = arith.addi %iv, %step
- // omp.yield %cmp = %step < 0 ? %v < %ub : %v > %ub
- // } fir.if %cmp {
- // fir.store %v to %loopIV
- // ^%lpv_update_blk:
- // }
- // omp.yield
- // }
+ // omp.sections lastprivate(...) {
+ // omp.section {...}
+ // omp.section {...}
+ // omp.section {
+ // fir.allocate for `private`/`firstprivate`
+ // <More operations here>
+ // fir.if %true {
+ // ^%lpv_update_blk
+ // }
+ // }
+ // }
//
-
- // Only generate the compare once in presence of multiple LastPrivate
- // clauses.
- if (cmpCreated)
- continue;
- cmpCreated = true;
-
- mlir::Location loc = op->getLoc();
- mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
- firOpBuilder.setInsertionPoint(lastOper);
-
- mlir::Value iv = op->getRegion(0).front().getArguments()[0];
- mlir::Value ub =
- mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getUpperBound()[0];
- mlir::Value step = mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getStep()[0];
-
- // v = iv + step
- // cmp = step < 0 ? v < ub : v > ub
- mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
- mlir::Value zero =
- firOpBuilder.createIntegerConstant(loc, step.getType(), 0);
- mlir::Value negativeStep = firOpBuilder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::slt, step, zero);
- mlir::Value vLT = firOpBuilder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::slt, v, ub);
- mlir::Value vGT = firOpBuilder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::sgt, v, ub);
- mlir::Value cmpOp = firOpBuilder.create<mlir::arith::SelectOp>(
- loc, negativeStep, vLT, vGT);
-
- auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
- firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- assert(loopIV && "loopIV was not set");
- firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
- lastPrivIP = firOpBuilder.saveInsertionPoint();
- } else {
- TODO(converter.getCurrentLocation(),
- "lastprivate clause in constructs other than "
- "simd/worksharing-loop");
+ // To keep code consistency while handling privatization
+ // through this control flow, add a `fir.if` operation
+ // that always evaluates to true, in order to create
+ // a dedicated sub-region in `omp.section` where
+ // lastprivate FIR can reside. Later canonicalizations
+ // will optimize away this operation.
+ if (!eval.lowerAsUnstructured()) {
+ auto ifOp = firOpBuilder.create<fir::IfOp>(
+ op->getLoc(),
+ firOpBuilder.createIntegerConstant(
+ op->getLoc(), firOpBuilder.getIntegerType(1), 0x1),
+ /*else*/ false);
+ firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+ const Fortran::parser::OpenMPConstruct *parentOmpConstruct =
+ eval.parentConstruct->getIf<Fortran::parser::OpenMPConstruct>();
+ assert(parentOmpConstruct &&
+ "Expected a valid enclosing OpenMP construct");
+ const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct =
+ std::get_if<Fortran::parser::OpenMPSectionsConstruct>(
+ &parentOmpConstruct->u);
+ assert(sectionsConstruct &&
+ "Expected an enclosing omp.sections construct");
+ const Fortran::parser::OmpClauseList &sectionsEndClauseList =
+ std::get<Fortran::parser::OmpClauseList>(
+ std::get<Fortran::parser::OmpEndSectionsDirective>(
+ sectionsConstruct->t)
+ .t);
+ for (const Fortran::parser::OmpClause &otherClause :
+ sectionsEndClauseList.v)
+ if (std::get_if<Fortran::parser::OmpClause::Nowait>(&otherClause.u))
+ // Emit implicit barrier to synchronize threads and avoid data
+ // races on post-update of lastprivate variables when `nowait`
+ // clause is present.
+ firOpBuilder.create<mlir::omp::BarrierOp>(
+ converter.getCurrentLocation());
+ firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ lastPrivIP = firOpBuilder.saveInsertionPoint();
+ firOpBuilder.setInsertionPoint(ifOp);
+ insPt = firOpBuilder.saveInsertionPoint();
+ } else {
+ // Lastprivate operation is inserted at the end
+ // of the lexically last section in the sections
+ // construct
+ mlir::OpBuilder::InsertPoint unstructuredSectionsIP =
+ firOpBuilder.saveInsertionPoint();
+ mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
+ firOpBuilder.setInsertionPoint(lastOper);
+ lastPrivIP = firOpBuilder.saveInsertionPoint();
+ firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP);
+ }
}
+ } else if (mlir::isa<mlir::omp::WsLoopOp>(op)) {
+ // Update the original variable just before exiting the worksharing
+ // loop. Conversion as follows:
+ //
+ // omp.wsloop {
+ // omp.wsloop { ...
+ // ... store
+ // store ===> %v = arith.addi %iv, %step
+ // omp.yield %cmp = %step < 0 ? %v < %ub : %v > %ub
+ // } fir.if %cmp {
+ // fir.store %v to %loopIV
+ // ^%lpv_update_blk:
+ // }
+ // omp.yield
+ // }
+ //
+
+ // Only generate the compare once in presence of multiple LastPrivate
+ // clauses.
+ if (cmpCreated)
+ continue;
+ cmpCreated = true;
+
+ mlir::Location loc = op->getLoc();
+ mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
+ firOpBuilder.setInsertionPoint(lastOper);
+
+ mlir::Value iv = op->getRegion(0).front().getArguments()[0];
+ mlir::Value ub =
+ mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getUpperBound()[0];
+ mlir::Value step = mlir::dyn_cast<mlir::omp::WsLoopOp>(op).getStep()[0];
+
+ // v = iv + step
+ // cmp = step < 0 ? v < ub : v > ub
+ mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
+ mlir::Value zero =
+ firOpBuilder.createIntegerConstant(loc, step.getType(), 0);
+ mlir::Value negativeStep = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::slt, step, zero);
+ mlir::Value vLT = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::slt, v, ub);
+ mlir::Value vGT = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::sgt, v, ub);
+ mlir::Value cmpOp = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, negativeStep, vLT, vGT);
+
+ auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
+ firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ assert(loopIV && "loopIV was not set");
+ firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
+ lastPrivIP = firOpBuilder.saveInsertionPoint();
+ } else {
+ TODO(converter.getCurrentLocation(),
+ "lastprivate clause in constructs other than "
+ "simd/worksharing-loop");
}
}
firOpBuilder.restoreInsertionPoint(localInsPt);
@@ -1595,14 +1592,12 @@ void DataSharingProcessor::collectSymbols(
}
void DataSharingProcessor::collectDefaultSymbols() {
- for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
- if (const auto &defaultClause =
- std::get_if<Fortran::parser::OmpClause::Default>(&clause.u)) {
- if (defaultClause->v.v ==
- Fortran::parser::OmpDefaultClause::Type::Private)
+ for (const omp::Clause &clause : clauses) {
+ if (const auto *defaultClause =
+ std::get_if<omp::clause::Default>(&clause.u)) {
+ if (defaultClause->v == omp::clause::Default::Type::Private)
collectSymbols(Fortran::semantics::Symbol::Flag::OmpPrivate);
- else if (defaultClause->v.v ==
- Fortran::parser::OmpDefaultClause::Type::Firstprivate)
+ else if (defaultClause->v == omp::clause::Default::Type::Firstprivate)
collectSymbols(Fortran::semantics::Symbol::Flag::OmpFirstPrivate);
}
}
@@ -3447,7 +3442,7 @@ static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) {
std::optional<DataSharingProcessor> tempDsp;
if (privatize) {
if (!info.dsp) {
- tempDsp.emplace(info.converter, *info.clauses, info.eval);
+ tempDsp.emplace(info.converter, info.semaCtx, *info.clauses, info.eval);
tempDsp->processStep1();
}
}
@@ -4448,7 +4443,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &loopOpClauseList,
mlir::Location loc) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- DataSharingProcessor dsp(converter, loopOpClauseList, eval);
+ DataSharingProcessor dsp(converter, semaCtx, loopOpClauseList, eval);
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
@@ -4505,7 +4500,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList *endClauseList,
mlir::Location loc) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- DataSharingProcessor dsp(converter, beginClauseList, eval);
+ DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval);
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
More information about the llvm-branch-commits
mailing list