[llvm-branch-commits] [flang] [flang][Lower] Convert OMP Map and related functions to evaluate::Expr (PR #81626)
Krzysztof Parzyszek via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Feb 19 11:55:14 PST 2024
https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/81626
>From 87437159da37749ad395d84a3fc1b729bd9e2480 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 8 Feb 2024 08:33:40 -0600
Subject: [PATCH] [flang][Lower] Convert OMP Map and related functions to
evaluate::Expr
The related functions are `gatherDataOperandAddrAndBounds` and
`genBoundsOps`. The former is also used by OpenACC; it was updated to
take evaluate::Expr instead of parser objects.
The difference in the test case comes from unfolded conversions
of index expressions, which are explicitly of type integer(kind=8).
Delete the now-unused `findRepeatableClause2` and `findClause2`.
Add an `AsGenericExpr` overload that takes std::optional. AsGenericExpr
already returns an optional Expr; also accepting an optional Expr as input
reduces the number of checks needed when handling the optional values that
are frequent in the evaluator.
---
flang/include/flang/Evaluate/tools.h | 8 +
flang/lib/Lower/DirectivesCommon.h | 389 ++++++++++++++++-----------
flang/lib/Lower/OpenACC.cpp | 54 ++--
flang/lib/Lower/OpenMP.cpp | 105 +++-----
4 files changed, 311 insertions(+), 245 deletions(-)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index e9999974944e88..d5713cfe420a2e 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -148,6 +148,14 @@ inline Expr<SomeType> AsGenericExpr(Expr<SomeType> &&x) { return std::move(x); }
std::optional<Expr<SomeType>> AsGenericExpr(DataRef &&);
std::optional<Expr<SomeType>> AsGenericExpr(const Symbol &);
+// Optional-input overload of AsGenericExpr: an empty optional yields
+// std::nullopt; otherwise the contained value is moved into the matching
+// AsGenericExpr overload and its result is returned.
+template <typename A>
+std::optional<Expr<SomeType>> AsGenericExpr(std::optional<A> &&x) {
+  if (x) {
+    return AsGenericExpr(std::move(*x));
+  }
+  return std::nullopt;
+}
+
template <typename A>
common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(
A &&x) {
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 8d560db34e05bf..2fa90572bc63eb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -808,6 +808,75 @@ genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
return bounds;
}
+namespace detail {
+// AsRvalueRef(x) produces a value that can be handed to a by-value or
+// rvalue-reference parameter. For an rvalue argument it returns an rvalue
+// reference (std::move). For an lvalue argument (const or not), the more
+// specialized lvalue-reference overloads are selected and a copy is returned,
+// so the original object is never moved from.
+template <typename T> //
+static T &&AsRvalueRef(T &&t) {
+  return std::move(t);
+}
+template <typename T> //
+static T AsRvalueRef(T &t) {
+  return t;
+}
+template <typename T> //
+static T AsRvalueRef(const T &t) {
+  return t;
+}
+
+// Helper class for stripping an outer conversion that preserves the type
+// category (e.g. int(i, kind=8) applied to an integer i of another kind).
+// This is used for triplet elements, which are always of type
+// integer(kind=8). The lower/upper bounds are converted to an "index" type,
+// which is 64-bit, so the explicit conversion to kind=8 (if present) is not
+// needed. When it's present, though, it causes generated names to contain
+// "int(..., kind=8)".
+// Only a Convert whose source category equals the result category is
+// removed. Enclosing parentheses are NOT stripped: any other outermost
+// operation (including Parentheses) makes visit() return std::nullopt.
+struct PeelConvert {
+  // Typed expression: look at its outermost operation.
+  template <Fortran::common::TypeCategory Category, int Kind>
+  static Fortran::semantics::MaybeExpr visit_with_category(
+      const Fortran::evaluate::Expr<Fortran::evaluate::Type<Category, Kind>>
+          &expr) {
+    return std::visit(
+        [](auto &&s) { return visit_with_category<Category, Kind>(s); },
+        expr.u);
+  }
+  // Same-category conversion: return its operand as a generic expression.
+  template <Fortran::common::TypeCategory Category, int Kind>
+  static Fortran::semantics::MaybeExpr visit_with_category(
+      const Fortran::evaluate::Convert<Fortran::evaluate::Type<Category, Kind>,
+                                       Category> &expr) {
+    return AsGenericExpr(AsRvalueRef(expr.left()));
+  }
+  // Any other operation: nothing to peel.
+  template <Fortran::common::TypeCategory Category, int Kind, typename T>
+  static Fortran::semantics::MaybeExpr visit_with_category(const T &) {
+    return std::nullopt; //
+  }
+  template <Fortran::common::TypeCategory Category, typename T>
+  static Fortran::semantics::MaybeExpr visit_with_category(const T &) {
+    return std::nullopt; //
+  }
+
+  template <Fortran::common::TypeCategory Category>
+  static Fortran::semantics::MaybeExpr
+  visit(const Fortran::evaluate::Expr<Fortran::evaluate::SomeKind<Category>>
+            &expr) {
+    return std::visit([](auto &&s) { return visit_with_category<Category>(s); },
+                      expr.u);
+  }
+  static Fortran::semantics::MaybeExpr
+  visit(const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr) {
+    return std::visit([](auto &&s) { return visit(s); }, expr.u);
+  }
+  template <typename T> //
+  static Fortran::semantics::MaybeExpr visit(const T &) {
+    return std::nullopt;
+  }
+};
+
+// Return `expr` with an outer category-preserving conversion removed, or a
+// copy of `expr` when its outermost operation is not such a conversion.
+// `inline` (rather than `static`) because this non-template function is
+// defined in a header: it avoids a private copy, and an unused-function
+// warning, in every translation unit that includes the header.
+inline Fortran::semantics::SomeExpr
+peelOuterConvert(const Fortran::semantics::SomeExpr &expr) {
+  if (auto peeled = PeelConvert::visit(expr))
+    return std::move(*peeled);
+  return expr;
+}
+} // namespace detail
+
/// Generate bounds operations for an array section when subscripts are
/// provided.
template <typename BoundsOp, typename BoundsType>
@@ -815,7 +884,7 @@ llvm::SmallVector<mlir::Value>
genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext &stmtCtx,
- const std::list<Fortran::parser::SectionSubscript> &subscripts,
+ const std::vector<Fortran::evaluate::Subscript> &subscripts,
std::stringstream &asFortran, fir::ExtendedValue &dataExv,
bool dataExvIsAssumedSize, AddrAndBoundsInfo &info,
bool treatIndexAsSection = false) {
@@ -828,8 +897,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
const int dataExvRank = static_cast<int>(dataExv.rank());
for (const auto &subscript : subscripts) {
- const auto *triplet{
- std::get_if<Fortran::parser::SubscriptTriplet>(&subscript.u)};
+ const auto *triplet{std::get_if<Fortran::evaluate::Triplet>(&subscript.u)};
if (triplet || treatIndexAsSection) {
if (dimension != 0)
asFortran << ',';
@@ -868,13 +936,18 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
strideInBytes = true;
}
- const Fortran::lower::SomeExpr *lower{nullptr};
+ Fortran::semantics::MaybeExpr lower;
if (triplet) {
- if (const auto &tripletLb{std::get<0>(triplet->t)})
- lower = Fortran::semantics::GetExpr(*tripletLb);
+ if ((lower = Fortran::evaluate::AsGenericExpr(triplet->lower())))
+ lower = detail::peelOuterConvert(*lower);
} else {
- const auto &index{std::get<Fortran::parser::IntExpr>(subscript.u)};
- lower = Fortran::semantics::GetExpr(index);
+ // Case of IndirectSubscriptIntegerExpr
+ using IndirectSubscriptIntegerExpr =
+ Fortran::evaluate::IndirectSubscriptIntegerExpr;
+ using SubscriptInteger = Fortran::evaluate::SubscriptInteger;
+ Fortran::evaluate::Expr<SubscriptInteger> oneInt =
+ std::get<IndirectSubscriptIntegerExpr>(subscript.u).value();
+ lower = Fortran::evaluate::AsGenericExpr(std::move(oneInt));
if (lower->Rank() > 0) {
mlir::emitError(
loc, "vector subscript cannot be used for an array section");
@@ -912,10 +985,12 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
extent = one;
} else {
asFortran << ':';
- const auto &upper{std::get<1>(triplet->t)};
+ Fortran::semantics::MaybeExpr upper =
+ Fortran::evaluate::AsGenericExpr(triplet->upper());
if (upper) {
- uval = Fortran::semantics::GetIntValue(upper);
+ upper = detail::peelOuterConvert(*upper);
+ uval = Fortran::evaluate::ToInt64(*upper);
if (uval) {
if (defaultLb) {
ubound = builder.createIntegerConstant(loc, idxTy, *uval - 1);
@@ -925,22 +1000,21 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
}
asFortran << *uval;
} else {
- const Fortran::lower::SomeExpr *uexpr =
- Fortran::semantics::GetExpr(*upper);
mlir::Value ub =
- fir::getBase(converter.genExprValue(loc, *uexpr, stmtCtx));
+ fir::getBase(converter.genExprValue(loc, *upper, stmtCtx));
ub = builder.createConvert(loc, baseLb.getType(), ub);
ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb);
- asFortran << uexpr->AsFortran();
+ asFortran << upper->AsFortran();
}
}
if (lower && upper) {
if (lval && uval && *uval < *lval) {
mlir::emitError(loc, "zero sized array section");
break;
- } else if (std::get<2>(triplet->t)) {
- const auto &strideExpr{std::get<2>(triplet->t)};
- if (strideExpr) {
+ } else {
+ // Stride is mandatory in evaluate::Triplet. Make sure it's 1.
+ auto val = Fortran::evaluate::ToInt64(triplet->GetStride());
+ if (!val || *val != 1) {
mlir::emitError(loc, "stride cannot be specified on "
"an array section");
break;
@@ -993,150 +1067,157 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
return bounds;
}
-template <typename ObjectType, typename BoundsOp, typename BoundsType>
+namespace detail {
+/// Extract the `Ref` alternative (e.g. evaluate::ArrayRef,
+/// evaluate::Component or semantics::SymbolRef) of a data reference.
+/// `expr` is either an evaluate::DataRef, whose variant is inspected
+/// directly, or an expression accepted by Fortran::evaluate::ExtractDataRef.
+/// Returns a copy of the alternative, or std::nullopt when there is no data
+/// reference or it holds a different alternative.
+template <typename Ref, typename Expr> //
+std::optional<Ref> getRef(Expr &&expr) {
+  using Arg = llvm::remove_cvref_t<Expr>;
+  if constexpr (!std::is_same_v<Arg, Fortran::evaluate::DataRef>) {
+    if (auto dataRef = Fortran::evaluate::ExtractDataRef(expr))
+      if (const auto *alt = std::get_if<Ref>(&dataRef->u))
+        return *alt;
+    return std::nullopt;
+  } else {
+    if (const auto *alt = std::get_if<Ref>(&expr.u))
+      return *alt;
+    return std::nullopt;
+  }
+}
+} // namespace detail
+
+template <typename BoundsOp, typename BoundsType>
AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::StatementContext &stmtCtx, const ObjectType &object,
+ semantics::SemanticsContext &semaCtx,
+ Fortran::lower::StatementContext &stmtCtx,
+ Fortran::semantics::SymbolRef symbol,
+ const Fortran::semantics::MaybeExpr &maybeDesignator,
mlir::Location operandLocation, std::stringstream &asFortran,
llvm::SmallVector<mlir::Value> &bounds, bool treatIndexAsSection = false) {
+ using namespace Fortran;
+
AddrAndBoundsInfo info;
- std::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::Designator &designator) {
- if (auto expr{Fortran::semantics::AnalyzeExpr(semanticsContext,
- designator)}) {
- if (((*expr).Rank() > 0 || treatIndexAsSection) &&
- Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
- designator)) {
- const auto *arrayElement =
- Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
- designator);
- const auto *dataRef =
- std::get_if<Fortran::parser::DataRef>(&designator.u);
- fir::ExtendedValue dataExv;
- bool dataExvIsAssumedSize = false;
- if (Fortran::parser::Unwrap<
- Fortran::parser::StructureComponent>(
- arrayElement->base)) {
- auto exprBase = Fortran::semantics::AnalyzeExpr(
- semanticsContext, arrayElement->base);
- dataExv = converter.genExprAddr(operandLocation, *exprBase,
- stmtCtx);
- info.addr = fir::getBase(dataExv);
- info.rawInput = info.addr;
- asFortran << (*exprBase).AsFortran();
- } else {
- const Fortran::parser::Name &name =
- Fortran::parser::GetLastName(*dataRef);
- dataExvIsAssumedSize = Fortran::semantics::IsAssumedSizeArray(
- name.symbol->GetUltimate());
- info = getDataOperandBaseAddr(converter, builder,
- *name.symbol, operandLocation);
- dataExv = converter.getSymbolExtendedValue(*name.symbol);
- asFortran << name.ToString();
- }
-
- if (!arrayElement->subscripts.empty()) {
- asFortran << '(';
- bounds = genBoundsOps<BoundsOp, BoundsType>(
- builder, operandLocation, converter, stmtCtx,
- arrayElement->subscripts, asFortran, dataExv,
- dataExvIsAssumedSize, info, treatIndexAsSection);
- }
- asFortran << ')';
- } else if (auto structComp = Fortran::parser::Unwrap<
- Fortran::parser::StructureComponent>(designator)) {
- fir::ExtendedValue compExv =
- converter.genExprAddr(operandLocation, *expr, stmtCtx);
- info.addr = fir::getBase(compExv);
- info.rawInput = info.addr;
- if (fir::unwrapRefType(info.addr.getType())
- .isa<fir::SequenceType>())
- bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
- builder, operandLocation, converter, compExv,
- /*isAssumedSize=*/false);
- asFortran << (*expr).AsFortran();
-
- bool isOptional = Fortran::semantics::IsOptional(
- *Fortran::parser::GetLastName(*structComp).symbol);
- if (isOptional)
- info.isPresent = builder.create<fir::IsPresentOp>(
- operandLocation, builder.getI1Type(), info.rawInput);
-
- if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(
- info.addr.getDefiningOp())) {
- if (fir::isAllocatableType(loadOp.getType()) ||
- fir::isPointerType(loadOp.getType()))
- info.addr = builder.create<fir::BoxAddrOp>(operandLocation,
- info.addr);
- info.rawInput = info.addr;
- }
-
- // If the component is an allocatable or pointer the result of
- // genExprAddr will be the result of a fir.box_addr operation or
- // a fir.box_addr has been inserted just before.
- // Retrieve the box so we handle it like other descriptor.
- if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
- info.addr.getDefiningOp())) {
- info.addr = boxAddrOp.getVal();
- info.rawInput = info.addr;
- bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
- builder, operandLocation, converter, compExv, info);
- }
- } else {
- if (Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
- designator)) {
- // Single array element.
- const auto *arrayElement =
- Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
- designator);
- (void)arrayElement;
- fir::ExtendedValue compExv =
- converter.genExprAddr(operandLocation, *expr, stmtCtx);
- info.addr = fir::getBase(compExv);
- info.rawInput = info.addr;
- asFortran << (*expr).AsFortran();
- } else if (const auto *dataRef{
- std::get_if<Fortran::parser::DataRef>(
- &designator.u)}) {
- // Scalar or full array.
- const Fortran::parser::Name &name =
- Fortran::parser::GetLastName(*dataRef);
- fir::ExtendedValue dataExv =
- converter.getSymbolExtendedValue(*name.symbol);
- info = getDataOperandBaseAddr(converter, builder,
- *name.symbol, operandLocation);
- if (fir::unwrapRefType(info.addr.getType())
- .isa<fir::BaseBoxType>()) {
- bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
- builder, operandLocation, converter, dataExv, info);
- }
- bool dataExvIsAssumedSize =
- Fortran::semantics::IsAssumedSizeArray(
- name.symbol->GetUltimate());
- if (fir::unwrapRefType(info.addr.getType())
- .isa<fir::SequenceType>())
- bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
- builder, operandLocation, converter, dataExv,
- dataExvIsAssumedSize);
- asFortran << name.ToString();
- } else { // Unsupported
- llvm::report_fatal_error(
- "Unsupported type of OpenACC operand");
- }
- }
- }
- },
- [&](const Fortran::parser::Name &name) {
- info = getDataOperandBaseAddr(converter, builder, *name.symbol,
- operandLocation);
- asFortran << name.ToString();
- }},
- object.u);
+
+  // No designator: the operand is just a symbol (e.g. a name without
+  // subscripts or components); only its base address is needed.
+  if (!maybeDesignator) {
+    info = getDataOperandBaseAddr(converter, builder, symbol, operandLocation);
+    asFortran << symbol->name().ToString();
+    return info;
+  }
+
+  // Bind by reference: the designator is only read below, so there is no
+  // need to deep-copy the expression tree.
+  const semantics::SomeExpr &designator = *maybeDesignator;
+
+  if ((designator.Rank() > 0 || treatIndexAsSection) &&
+      IsArrayElement(designator)) {
+    auto arrayRef = detail::getRef<evaluate::ArrayRef>(designator);
+    // This shouldn't fail after IsArrayElement(designator).
+    assert(arrayRef && "Expecting ArrayRef");
+
+    fir::ExtendedValue dataExv;
+    bool dataExvIsAssumedSize = false;
+
+    // Build a designator expression for the base of the array reference,
+    // which is either a plain symbol or a structure component.
+    auto toMaybeExpr = [&](auto &&base) {
+      using BaseType = llvm::remove_cvref_t<decltype(base)>;
+      evaluate::ExpressionAnalyzer ea{semaCtx};
+
+      if constexpr (std::is_same_v<evaluate::NamedEntity, BaseType>) {
+        if (auto *ref = base.UnwrapSymbolRef())
+          return ea.Designate(evaluate::DataRef{*ref});
+        if (auto *ref = base.UnwrapComponent())
+          return ea.Designate(evaluate::DataRef{*ref});
+        llvm_unreachable("Unexpected NamedEntity");
+      } else {
+        static_assert(std::is_same_v<semantics::SymbolRef, BaseType>);
+        return ea.Designate(evaluate::DataRef{base});
+      }
+    };
+
+    auto arrayBase = toMaybeExpr(arrayRef->base());
+    // An assert alone would let release builds dereference an empty optional
+    // (undefined behavior) if the base could not be designated.
+    if (!arrayBase)
+      llvm::report_fatal_error(
+          "Unable to designate the base of an array element");
+
+    if (detail::getRef<evaluate::Component>(*arrayBase)) {
+      // Base is a structure component: generate its address directly.
+      dataExv = converter.genExprAddr(operandLocation, *arrayBase, stmtCtx);
+      info.addr = fir::getBase(dataExv);
+      info.rawInput = info.addr;
+      asFortran << arrayBase->AsFortran();
+    } else {
+      const semantics::Symbol &sym = arrayRef->GetLastSymbol();
+      dataExvIsAssumedSize =
+          Fortran::semantics::IsAssumedSizeArray(sym.GetUltimate());
+      info = getDataOperandBaseAddr(converter, builder, sym, operandLocation);
+      dataExv = converter.getSymbolExtendedValue(sym);
+      asFortran << sym.name().ToString();
+    }
+
+    // Emit the closing ')' only together with the opening '(' so the
+    // generated name is always balanced.
+    if (!arrayRef->subscript().empty()) {
+      asFortran << '(';
+      bounds = genBoundsOps<BoundsOp, BoundsType>(
+          builder, operandLocation, converter, stmtCtx, arrayRef->subscript(),
+          asFortran, dataExv, dataExvIsAssumedSize, info, treatIndexAsSection);
+      asFortran << ')';
+    }
+  } else if (auto compRef = detail::getRef<evaluate::Component>(designator)) {
+    fir::ExtendedValue compExv =
+        converter.genExprAddr(operandLocation, designator, stmtCtx);
+    info.addr = fir::getBase(compExv);
+    info.rawInput = info.addr;
+    if (fir::unwrapRefType(info.addr.getType()).isa<fir::SequenceType>())
+      bounds = genBaseBoundsOps<BoundsOp, BoundsType>(builder, operandLocation,
+                                                      converter, compExv,
+                                                      /*isAssumedSize=*/false);
+    asFortran << designator.AsFortran();
+
+    if (semantics::IsOptional(compRef->GetLastSymbol())) {
+      info.isPresent = builder.create<fir::IsPresentOp>(
+          operandLocation, builder.getI1Type(), info.rawInput);
+    }
+
+    if (auto loadOp =
+            mlir::dyn_cast_or_null<fir::LoadOp>(info.addr.getDefiningOp())) {
+      if (fir::isAllocatableType(loadOp.getType()) ||
+          fir::isPointerType(loadOp.getType()))
+        info.addr = builder.create<fir::BoxAddrOp>(operandLocation, info.addr);
+      info.rawInput = info.addr;
+    }
+
+    // If the component is an allocatable or pointer the result of
+    // genExprAddr will be the result of a fir.box_addr operation or
+    // a fir.box_addr has been inserted just before.
+    // Retrieve the box so we handle it like other descriptor.
+    if (auto boxAddrOp =
+            mlir::dyn_cast_or_null<fir::BoxAddrOp>(info.addr.getDefiningOp())) {
+      info.addr = boxAddrOp.getVal();
+      info.rawInput = info.addr;
+      bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
+          builder, operandLocation, converter, compExv, info);
+    }
+  } else {
+    if (detail::getRef<evaluate::ArrayRef>(designator)) {
+      // Single array element.
+      fir::ExtendedValue compExv =
+          converter.genExprAddr(operandLocation, designator, stmtCtx);
+      info.addr = fir::getBase(compExv);
+      info.rawInput = info.addr;
+      asFortran << designator.AsFortran();
+    } else if (auto symRef = detail::getRef<semantics::SymbolRef>(designator)) {
+      // Scalar or full array.
+      fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*symRef);
+      info =
+          getDataOperandBaseAddr(converter, builder, *symRef, operandLocation);
+      if (fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
+        bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
+            builder, operandLocation, converter, dataExv, info);
+      }
+      bool dataExvIsAssumedSize =
+          Fortran::semantics::IsAssumedSizeArray(symRef->get().GetUltimate());
+      if (fir::unwrapRefType(info.addr.getType()).isa<fir::SequenceType>())
+        bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
+            builder, operandLocation, converter, dataExv, dataExvIsAssumedSize);
+      asFortran << symRef->get().name().ToString();
+    } else { // Unsupported
+      // This helper serves both OpenACC and OpenMP data clauses.
+      llvm::report_fatal_error("Unsupported type of OpenACC/OpenMP operand");
+    }
+  }
+
return info;
}
-
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 6ae270f63f5cf4..a444682306ac20 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -269,6 +269,11 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
Fortran::parser::GetLastName(arrayElement->base);
return *name.symbol;
}
+ if (const auto *component =
+ Fortran::parser::Unwrap<Fortran::parser::StructureComponent>(
+ *designator)) {
+ return *component->component.symbol;
+ }
} else if (const auto *name =
std::get_if<Fortran::parser::Name>(&accObject.u)) {
return *name->symbol;
@@ -286,17 +291,20 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
mlir::acc::DataClause dataClause, bool structured,
bool implicit, bool setDeclareAttr = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
+ Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
+ Fortran::semantics::MaybeExpr designator =
+ std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
- mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
- stmtCtx, accObject, operandLocation,
- asFortran, bounds,
- /*treatIndexAsSection=*/true);
+ mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+ converter, builder, semanticsContext, stmtCtx, symbol, designator,
+ operandLocation, asFortran, bounds,
+ /*treatIndexAsSection=*/true);
// If the input value is optional and is not a descriptor, we use the
// rawInput directly.
@@ -321,16 +329,19 @@ static void genDeclareDataOperandOperations(
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
mlir::acc::DataClause dataClause, bool structured, bool implicit) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
+ Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
+ Fortran::semantics::MaybeExpr designator =
+ std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
- mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
- stmtCtx, accObject, operandLocation,
- asFortran, bounds);
+ mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+ converter, builder, semanticsContext, stmtCtx, symbol, designator,
+ operandLocation, asFortran, bounds);
EntryOp op = createDataEntryOp<EntryOp>(
builder, operandLocation, info.addr, asFortran, bounds, structured,
implicit, dataClause, info.addr.getType());
@@ -339,8 +350,7 @@ static void genDeclareDataOperandOperations(
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
modBuilder.setInsertionPointAfter(builder.getFunction());
- std::string prefix =
- converter.mangleName(getSymbolFromAccObject(accObject));
+ std::string prefix = converter.mangleName(symbol);
createDeclareAllocFuncWithArg<EntryOp>(
modBuilder, builder, operandLocation, info.addr.getType(), prefix,
asFortran, dataClause);
@@ -783,16 +793,19 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
llvm::SmallVector<mlir::Attribute> &privatizations) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
+ Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
+ Fortran::semantics::MaybeExpr designator =
+ std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
- mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
- stmtCtx, accObject, operandLocation,
- asFortran, bounds);
+ mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+ converter, builder, semanticsContext, stmtCtx, symbol, designator,
+ operandLocation, asFortran, bounds);
RecipeOp recipe;
mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
@@ -1361,16 +1374,19 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
const auto &op =
std::get<Fortran::parser::AccReductionOperator>(objectList.t);
mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
+ Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objects.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
+ Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
+ Fortran::semantics::MaybeExpr designator =
+ std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
- mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
- stmtCtx, accObject, operandLocation,
- asFortran, bounds);
+ mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+ converter, builder, semanticsContext, stmtCtx, symbol, designator,
+ operandLocation, asFortran, bounds);
mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index caae5c0cef9251..4309d69434839f 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1789,18 +1789,6 @@ class ClauseProcessor {
return end;
}
- /// Utility to find a clause within a range in the clause list.
- template <typename T>
- static ClauseIterator2 findClause2(ClauseIterator2 begin,
- ClauseIterator2 end) {
- for (ClauseIterator2 it = begin; it != end; ++it) {
- if (std::get_if<T>(&it->u))
- return it;
- }
-
- return end;
- }
-
/// Return the first instance of the given clause found in the clause list or
/// `nullptr` if not present. If more than one instance is expected, use
/// `findRepeatableClause` instead.
@@ -1836,26 +1824,6 @@ class ClauseProcessor {
return found;
}
- /// Call `callbackFn` for each occurrence of the given clause. Return `true`
- /// if at least one instance was found.
- template <typename T>
- bool findRepeatableClause2(
- std::function<void(const T *, const Fortran::parser::CharBlock &source)>
- callbackFn) const {
- bool found = false;
- ClauseIterator2 nextIt, endIt = clauses2.v.end();
- for (ClauseIterator2 it = clauses2.v.begin(); it != endIt; it = nextIt) {
- nextIt = findClause2<T>(it, endIt);
-
- if (nextIt != endIt) {
- callbackFn(&std::get<T>(nextIt->u), nextIt->source);
- found = true;
- ++nextIt;
- }
- }
- return found;
- }
-
/// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
template <typename T>
bool markClauseOccurrence(mlir::UnitAttr &result) const {
@@ -2958,65 +2926,61 @@ bool ClauseProcessor::processMap(
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- return findRepeatableClause2<ClauseTy::Map>(
- [&](const ClauseTy::Map *mapClause,
+ return findRepeatableClause<omp::clause::Map>(
+ [&](const omp::clause::Map &clause,
const Fortran::parser::CharBlock &source) {
+ using Map = omp::clause::Map;
mlir::Location clauseLocation = converter.genLocation(source);
- const auto &oMapType =
- std::get<std::optional<Fortran::parser::OmpMapType>>(
- mapClause->v.t);
+ const auto &oMapType = std::get<std::optional<Map::MapType>>(clause.t);
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
// If the map type is specified, then process it else Tofrom is the
// default.
if (oMapType) {
- const Fortran::parser::OmpMapType::Type &mapType =
- std::get<Fortran::parser::OmpMapType::Type>(oMapType->t);
+ const Map::MapType::Type &mapType =
+ std::get<Map::MapType::Type>(oMapType->t);
switch (mapType) {
- case Fortran::parser::OmpMapType::Type::To:
+ case Map::MapType::Type::To:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
break;
- case Fortran::parser::OmpMapType::Type::From:
+ case Map::MapType::Type::From:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
break;
- case Fortran::parser::OmpMapType::Type::Tofrom:
+ case Map::MapType::Type::Tofrom:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
break;
- case Fortran::parser::OmpMapType::Type::Alloc:
- case Fortran::parser::OmpMapType::Type::Release:
+ case Map::MapType::Type::Alloc:
+ case Map::MapType::Type::Release:
// alloc and release is the default map_type for the Target Data
// Ops, i.e. if no bits for map_type is supplied then alloc/release
// is implicitly assumed based on the target directive. Default
// value for Target Data and Enter Data is alloc and for Exit Data
// it is release.
break;
- case Fortran::parser::OmpMapType::Type::Delete:
+ case Map::MapType::Type::Delete:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
}
- if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
- oMapType->t))
+ if (std::get<std::optional<Map::MapType::Always>>(oMapType->t))
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
} else {
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
}
- for (const Fortran::parser::OmpObject &ompObject :
- std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) {
+ for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
- mlir::omp::DataBoundsType>(
- converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
- clauseLocation, asFortran, bounds, treatIndexAsSection);
+ mlir::omp::DataBoundsOp, mlir::omp::DataBoundsType>(
+ converter, firOpBuilder, semaCtx, stmtCtx, *object.sym,
+ object.dsg, clauseLocation, asFortran, bounds,
+ treatIndexAsSection);
- auto origSymbol =
- converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
+ auto origSymbol = converter.getSymbolAddress(*object.sym);
mlir::Value symAddr = info.addr;
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
symAddr = origSymbol;
@@ -3039,7 +3003,7 @@ bool ClauseProcessor::processMap(
mapSymLocs->push_back(symAddr.getLoc());
if (mapSymbols)
- mapSymbols->push_back(getOmpObjectSymbol(ompObject));
+ mapSymbols->push_back(object.sym);
}
});
}
@@ -3120,32 +3084,31 @@ template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
- return findRepeatableClause2<T>(
- [&](const T *motionClause, const Fortran::parser::CharBlock &source) {
+ return findRepeatableClause<T>(
+ [&](const T &clause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> ||
- std::is_same_v<T, ClauseProcessor::ClauseTy::From>);
+ static_assert(std::is_same_v<T, omp::clause::To> ||
+ std::is_same_v<T, omp::clause::From>);
// TODO Support motion modifiers: present, mapper, iterator.
constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
- std::is_same_v<T, ClauseProcessor::ClauseTy::To>
+ std::is_same_v<T, omp::clause::To>
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
- for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
+ for (const omp::Object &object : clause.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
- Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
- mlir::omp::DataBoundsType>(
- converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
- clauseLocation, asFortran, bounds, treatIndexAsSection);
+ mlir::omp::DataBoundsOp, mlir::omp::DataBoundsType>(
+ converter, firOpBuilder, semaCtx, stmtCtx, *object.sym,
+ object.dsg, clauseLocation, asFortran, bounds,
+ treatIndexAsSection);
- auto origSymbol =
- converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
+ auto origSymbol = converter.getSymbolAddress(*object.sym);
mlir::Value symAddr = info.addr;
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
symAddr = origSymbol;
@@ -3899,10 +3862,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
cp.processNowait(nowaitAttr);
if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
- cp.processMotionClauses<Fortran::parser::OmpClause::To>(stmtCtx,
- mapOperands);
- cp.processMotionClauses<Fortran::parser::OmpClause::From>(stmtCtx,
- mapOperands);
+ cp.processMotionClauses<omp::clause::To>(stmtCtx, mapOperands);
+ cp.processMotionClauses<omp::clause::From>(stmtCtx, mapOperands);
} else {
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
More information about the llvm-branch-commits
mailing list