[flang-commits] [flang] 8256867 - [flang] Fold FINDLOC()
peter klausler via flang-commits
flang-commits at lists.llvm.org
Thu Sep 30 12:08:18 PDT 2021
Author: Peter Klausler
Date: 2021-09-30T12:08:10-07:00
New Revision: 82568675087eebbe4f1c5c56f40969c43a738b83
URL: https://github.com/llvm/llvm-project/commit/82568675087eebbe4f1c5c56f40969c43a738b83
DIFF: https://github.com/llvm/llvm-project/commit/82568675087eebbe4f1c5c56f40969c43a738b83.diff
LOG: [flang] Fold FINDLOC()
Fold the transformational intrinsic function FINDLOC() for
all combinations of optional arguments and data types.
Differential Revision: https://reviews.llvm.org/D110757
Added:
flang/test/Evaluate/folding30.f90
Modified:
flang/include/flang/Evaluate/constant.h
flang/include/flang/Parser/provenance.h
flang/lib/Evaluate/constant.cpp
flang/lib/Evaluate/fold-character.cpp
flang/lib/Evaluate/fold-integer.cpp
flang/lib/Evaluate/fold-logical.cpp
flang/lib/Evaluate/fold-reduction.cpp
flang/lib/Evaluate/fold-reduction.h
flang/lib/Evaluate/shape.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Evaluate/constant.h b/flang/include/flang/Evaluate/constant.h
index ea89e4ee2ee23..89b98f389b629 100644
--- a/flang/include/flang/Evaluate/constant.h
+++ b/flang/include/flang/Evaluate/constant.h
@@ -65,6 +65,7 @@ class ConstantBounds {
const ConstantSubscripts &shape() const { return shape_; }
const ConstantSubscripts &lbounds() const { return lbounds_; }
void set_lbounds(ConstantSubscripts &&);
+ void SetLowerBoundsToOne();
int Rank() const { return GetRank(shape_); }
Constant<SubscriptInteger> SHAPE() const;
@@ -140,8 +141,8 @@ template <typename T> class Constant : public ConstantBase<T> {
}
}
- // Apply subscripts. An empty subscript list is allowed for
- // a scalar constant.
+ // Apply subscripts. Excess subscripts are ignored, including the
+ // case of a scalar.
Element At(const ConstantSubscripts &) const;
Constant Reshape(ConstantSubscripts &&) const;
diff --git a/flang/include/flang/Parser/provenance.h b/flang/include/flang/Parser/provenance.h
index 7ff475a2316ab..4ada5c81d6bac 100644
--- a/flang/include/flang/Parser/provenance.h
+++ b/flang/include/flang/Parser/provenance.h
@@ -30,7 +30,7 @@ namespace Fortran::parser {
// Each character in the contiguous source stream built by the
// prescanner corresponds to a particular character in a source file,
-// include file, macro expansion, or compiler-inserted padding.
+// include file, macro expansion, or compiler-inserted text.
// The location of this original character to which a parsable character
// corresponds is its provenance.
//
diff --git a/flang/lib/Evaluate/constant.cpp b/flang/lib/Evaluate/constant.cpp
index 8f30ca0811626..01f5657360eb4 100644
--- a/flang/lib/Evaluate/constant.cpp
+++ b/flang/lib/Evaluate/constant.cpp
@@ -14,15 +14,6 @@
namespace Fortran::evaluate {
-std::size_t TotalElementCount(const ConstantSubscripts &shape) {
- std::size_t size{1};
- for (auto dim : shape) {
- CHECK(dim >= 0);
- size *= dim;
- }
- return size;
-}
-
ConstantBounds::ConstantBounds(const ConstantSubscripts &shape)
: shape_(shape), lbounds_(shape_.size(), 1) {}
@@ -36,6 +27,12 @@ void ConstantBounds::set_lbounds(ConstantSubscripts &&lb) {
lbounds_ = std::move(lb);
}
+// Resets every lower bound to 1; the shape (extents) is left unchanged.
+// Used when folding intrinsics such as FINDLOC() whose results are
+// expressed relative to lower bounds of 1.
+void ConstantBounds::SetLowerBoundsToOne() {
+ for (auto &n : lbounds_) {
+ n = 1;
+ }
+}
+
Constant<SubscriptInteger> ConstantBounds::SHAPE() const {
return AsConstantShape(shape_);
}
@@ -55,6 +52,10 @@ ConstantSubscript ConstantBounds::SubscriptsToOffset(
return offset;
}
+// Number of elements in an array of the given shape, i.e. the product of
+// its extents. Delegates to GetSize(const ConstantSubscripts &), which
+// CHECKs that no extent is negative.
+std::size_t TotalElementCount(const ConstantSubscripts &shape) {
+ return static_cast<std::size_t>(GetSize(shape));
+}
+
bool ConstantBounds::IncrementSubscripts(
ConstantSubscripts &indices, const std::vector<int> *dimOrder) const {
int rank{GetRank(shape_)};
diff --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp
index 4fc37aa8e2d3d..7c37e4cb011fb 100644
--- a/flang/lib/Evaluate/fold-character.cpp
+++ b/flang/lib/Evaluate/fold-character.cpp
@@ -102,7 +102,7 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
CharacterUtils<KIND>::TRIM(std::get<Scalar<T>>(*scalar))}};
}
}
- // TODO: findloc, maxloc, minloc, transfer
+ // TODO: maxloc, minloc, transfer
return Expr<T>{std::move(funcRef)};
}
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index c69ce32da1882..441a5a9924f98 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -182,14 +182,8 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
if (const Constant<LogicalResult> *mask{arg.empty()
? nullptr
: Folder<LogicalResult>{context}.Folding(arg[0])}) {
- std::optional<ConstantSubscript> dim;
- if (arg.size() > 1 && arg[1]) {
- dim = CheckDIM(context, arg[1], mask->Rank());
- if (!dim) {
- mask = nullptr;
- }
- }
- if (mask) {
+ std::optional<int> dim;
+ if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
if (mask->At(at).IsTrue()) {
element = element.AddSigned(Scalar<T>{1}).value;
@@ -201,13 +195,159 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
return Expr<T>{std::move(ref)};
}
+// FINDLOC(ARRAY, VALUE [, DIM, MASK, KIND, BACK])
+// Folds FINDLOC() when ARRAY=, VALUE=, and any present DIM=, MASK=, and BACK=
+// arguments are constant. common::SearchTypes() instantiates Test<T>() for
+// each intrinsic type; only the instantiation matching ARRAY='s dynamic type
+// does any work.
+class FindlocHelper {
+public:
+  FindlocHelper(
+      DynamicType &&type, ActualArguments &arg, FoldingContext &context)
+      : type_{type}, arg_{arg}, context_{context} {}
+  using Result = std::optional<Constant<SubscriptInteger>>;
+  using Types = AllIntrinsicTypes;
+
+  // Returns the folded result (subscripts relative to lower bounds of 1,
+  // with 0 meaning "not found"), or std::nullopt when the call cannot be
+  // folded.
+  template <typename T> Result Test() const {
+    if (T::category != type_.category() || T::kind != type_.kind()) {
+      return std::nullopt;
+    }
+    // Arguments, in order: ARRAY, VALUE, DIM, MASK, KIND, BACK
+    CHECK(arg_.size() == 6);
+    Folder<T> folder{context_};
+    Constant<T> *array{folder.Folding(arg_[0])};
+    Constant<T> *value{folder.Folding(arg_[1])};
+    if (!array || !value) {
+      return std::nullopt;
+    }
+    std::optional<int> dim;
+    Constant<LogicalResult> *mask{
+        GetReductionMASK(arg_[3], array->shape(), context_)};
+    if ((!mask && arg_[3]) ||
+        !CheckReductionDIM(dim, context_, arg_, 2, array->Rank())) {
+      return std::nullopt;
+    }
+    bool back{false};
+    if (arg_[5]) {
+      const auto *backConst{Folder<LogicalResult>{context_}.Folding(arg_[5])};
+      if (backConst) {
+        back = backConst->GetScalarValue().value().IsTrue();
+      } else {
+        return std::nullopt;
+      }
+    }
+    // GetReductionMASK() accepts a scalar MASK=, which applies to every
+    // element. Resolve it here: the DIM= loop below steps maskAt[zbDim],
+    // which would index past the end of the empty subscript vector of a
+    // scalar mask.
+    bool anyEligible{true}; // false for a scalar MASK=.FALSE.
+    if (mask && mask->Rank() == 0) {
+      anyEligible = mask->GetScalarValue().value().IsTrue();
+      mask = nullptr;
+    }
+    // Use lower bounds of 1 exclusively.
+    array->SetLowerBoundsToOne();
+    ConstantSubscripts at{array->lbounds()}, maskAt, resultIndices, resultShape;
+    if (mask) { // from here on, mask is always array-valued
+      mask->SetLowerBoundsToOne();
+      maskAt = mask->lbounds();
+    }
+    if (dim) { // DIM=
+      if (*dim < 1 || *dim > array->Rank()) {
+        context_.messages().Say(
+            "FINDLOC(DIM=%d) is out of range"_err_en_US, *dim);
+        return std::nullopt;
+      }
+      int zbDim{*dim - 1};
+      resultShape = array->shape();
+      resultShape.erase(
+          resultShape.begin() + zbDim); // scalar if array is vector
+      ConstantSubscript dimLength{array->shape()[zbDim]};
+      ConstantSubscript n{GetSize(resultShape)};
+      for (ConstantSubscript j{0}; j < n; ++j) {
+        // Scan one vector along DIM=; 0 (lower bound - 1) means no match.
+        ConstantSubscript hit{array->lbounds()[zbDim] - 1};
+        for (ConstantSubscript k{0}; k < dimLength;
+             ++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
+          if (anyEligible && (!mask || mask->At(maskAt).IsTrue()) &&
+              IsHit(array->At(at), *value)) {
+            hit = at[zbDim];
+            if (!back) {
+              break;
+            }
+          }
+        }
+        resultIndices.emplace_back(hit);
+        // Move to the next vector along DIM=: park DIM='s subscript on its
+        // last value so that IncrementSubscripts() steps to the next
+        // combination of the other subscripts, then rewind it. (Assumes
+        // IncrementSubscripts() carries from lower to higher dimensions.)
+        at[zbDim] = array->lbounds()[zbDim] + dimLength - 1;
+        array->IncrementSubscripts(at);
+        at[zbDim] = array->lbounds()[zbDim];
+        if (mask) {
+          maskAt[zbDim] = mask->lbounds()[zbDim] + dimLength - 1;
+          mask->IncrementSubscripts(maskAt);
+          maskAt[zbDim] = mask->lbounds()[zbDim];
+        }
+      }
+    } else { // no DIM=
+      resultShape = ConstantSubscripts{array->Rank()}; // always a vector
+      ConstantSubscript n{GetSize(array->shape())};
+      resultIndices = ConstantSubscripts(array->Rank(), 0);
+      for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
+           mask && mask->IncrementSubscripts(maskAt)) {
+        if (anyEligible && (!mask || mask->At(maskAt).IsTrue()) &&
+            IsHit(array->At(at), *value)) {
+          resultIndices = at;
+          if (!back) {
+            break;
+          }
+        }
+      }
+    }
+    std::vector<Scalar<SubscriptInteger>> resultElements;
+    resultElements.reserve(resultIndices.size());
+    for (ConstantSubscript j : resultIndices) {
+      resultElements.emplace_back(j);
+    }
+    return Constant<SubscriptInteger>{
+        std::move(resultElements), std::move(resultShape)};
+  }
+
+private:
+  // Folds "element == value" (.EQV. for LOGICAL) to a constant and returns
+  // its truth value.
+  template <typename T>
+  bool IsHit(typename Constant<T>::Element element, Constant<T> value) const {
+    std::optional<Expr<LogicalResult>> cmp;
+    if constexpr (T::category == TypeCategory::Logical) {
+      // array(at) .EQV. value?
+      cmp.emplace(
+          ConvertToType<LogicalResult>(Expr<T>{LogicalOperation<T::kind>{
+              LogicalOperator::Eqv, Expr<T>{Constant<T>{std::move(element)}},
+              Expr<T>{std::move(value)}}}));
+    } else { // array(at) .EQ. value?
+      cmp.emplace(PackageRelation(RelationalOperator::EQ,
+          Expr<T>{Constant<T>{std::move(element)}}, Expr<T>{std::move(value)}));
+    }
+    Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
+    return GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
+  }
+
+  DynamicType type_;
+  ActualArguments &arg_;
+  FoldingContext &context_;
+};
+
+// Folds the FINDLOC() call whose actual arguments are `arg`, producing the
+// locations as SubscriptInteger values, or std::nullopt when it cannot be
+// folded. Dispatches on the dynamic type of ARRAY=.
+static std::optional<Constant<SubscriptInteger>> FoldFindlocCall(
+    ActualArguments &arg, FoldingContext &context) {
+  CHECK(arg.size() == 6);
+  if (!arg[0]) {
+    return std::nullopt; // no ARRAY= argument
+  }
+  auto arrayType{arg[0]->GetType()};
+  if (!arrayType) {
+    return std::nullopt; // ARRAY= type is unknown
+  }
+  return common::SearchTypes(
+      FindlocHelper{std::move(*arrayType), arg, context});
+}
+
+// Folds FINDLOC() for an INTEGER result type T, converting the
+// SubscriptInteger locations to T's kind. Returns the call unchanged
+// when it cannot be folded.
+template <typename T>
+static Expr<T> FoldFindloc(FoldingContext &context, FunctionRef<T> &&ref) {
+  static_assert(T::category == TypeCategory::Integer);
+  std::optional<Constant<SubscriptInteger>> found{
+      FoldFindlocCall(ref.arguments(), context)};
+  if (!found) {
+    return Expr<T>{std::move(ref)};
+  }
+  Expr<SubscriptInteger> locations{std::move(*found)};
+  return Expr<T>{Fold(context, ConvertToType<T>(std::move(locations)))};
+}
+
// for IALL, IANY, & IPARITY
template <typename T>
static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
Scalar<T> identity) {
static_assert(T::category == TypeCategory::Integer);
- std::optional<ConstantSubscript> dim;
+ std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
@@ -310,6 +450,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else {
DIE("exponent argument must be real");
}
+ } else if (name == "findloc") {
+ return FoldFindloc<T>(context, std::move(funcRef));
} else if (name == "huge") {
return Expr<T>{Scalar<T>::HUGE()};
} else if (name == "iachar" || name == "ichar") {
@@ -711,7 +853,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else if (name == "ubound") {
return UBOUND(context, std::move(funcRef));
}
- // TODO: dot_product, findloc, ibits, image_status, ishftc,
+ // TODO: dot_product, ibits, image_status, ishftc,
// matmul, maxloc, minloc, sign, transfer
return Expr<T>{std::move(funcRef)};
}
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index 71a8f701991bd..1e11fec256405 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -19,7 +19,7 @@ static Expr<T> FoldAllAny(FoldingContext &context, FunctionRef<T> &&ref,
Scalar<T> identity) {
static_assert(T::category == TypeCategory::Logical);
using Element = Scalar<T>;
- std::optional<ConstantSubscript> dim;
+ std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
diff --git a/flang/lib/Evaluate/fold-reduction.cpp b/flang/lib/Evaluate/fold-reduction.cpp
index f171f859dc064..56f4b70b4f667 100644
--- a/flang/lib/Evaluate/fold-reduction.cpp
+++ b/flang/lib/Evaluate/fold-reduction.cpp
@@ -9,24 +9,39 @@
#include "fold-reduction.h"
namespace Fortran::evaluate {
-
-std::optional<ConstantSubscript> CheckDIM(
- FoldingContext &context, std::optional<ActualArgument> &arg, int rank) {
- if (arg) {
- if (auto *dimConst{Folder<SubscriptInteger>{context}.Folding(arg)}) {
+// Folds and validates the DIM= argument arg[*dimIndex] of a reduction over
+// an array of the given rank.
+//  - No DIM= (dimIndex absent or beyond arg, or the argument omitted):
+//    returns true and leaves `dim` empty.
+//  - A scalar constant DIM= in [1, rank]: stores it in `dim`, returns true.
+//  - Otherwise -- a constant DIM= out of range (diagnosed), or a DIM= that
+//    is present but not a scalar constant -- returns false so the caller
+//    does not fold. Treating a non-constant DIM= as absent would fold the
+//    call as a whole-array reduction with the wrong value and rank.
+bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &context,
+    ActualArguments &arg, std::optional<int> dimIndex, int rank) {
+  if (!dimIndex || static_cast<std::size_t>(*dimIndex) >= arg.size() ||
+      !arg[*dimIndex]) {
+    dim.reset();
+    return true; // no DIM= argument
+  }
+  if (auto *dimConst{
+          Folder<SubscriptInteger>{context}.Folding(arg[*dimIndex])}) {
if (auto dimScalar{dimConst->GetScalarValue()}) {
- auto dim{dimScalar->ToInt64()};
- if (dim >= 1 && dim <= rank) {
- return {dim};
+      auto dimVal{dimScalar->ToInt64()};
+      if (dimVal >= 1 && dimVal <= rank) {
+        dim = dimVal;
+        return true; // DIM= is a valid constant
} else {
context.messages().Say(
"DIM=%jd is not valid for an array of rank %d"_err_en_US,
- static_cast<std::intmax_t>(dim), rank);
+            static_cast<std::intmax_t>(dimVal), rank);
}
}
}
-}
- return std::nullopt;
+  return false; // DIM= is invalid or not a scalar constant
}
+// Folds MASK= and checks it against an ARRAY= of the given shape. Returns
+// null when MASK= is absent, not constant, or does not conform; a scalar
+// MASK= is accepted and returned unchanged.
+Constant<LogicalResult> *GetReductionMASK(
+    std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
+    FoldingContext &context) {
+  Constant<LogicalResult> *mask{
+      Folder<LogicalResult>{context}.Folding(maskArg)};
+  if (!mask) {
+    return nullptr; // absent or not constant
+  }
+  bool conforms{CheckConformance(context.messages(), AsShape(shape),
+      AsShape(mask->shape()), CheckConformanceFlags::RightScalarExpandable,
+      "ARRAY=", "MASK=")
+                    .value_or(false)};
+  return conforms ? mask : nullptr;
+}
} // namespace Fortran::evaluate
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 714de7c2b51d5..bfa4a1f80c79b 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-// TODO: DOT_PRODUCT, FINDLOC, NORM2, MAXLOC, MINLOC, PARITY
+// TODO: DOT_PRODUCT, NORM2, MAXLOC, MINLOC, PARITY
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
@@ -15,9 +15,15 @@
namespace Fortran::evaluate {
-// Folds & validates a DIM= actual argument.
-std::optional<ConstantSubscript> CheckDIM(
- FoldingContext &, std::optional<ActualArgument> &, int rank);
+// Fold and validate the DIM= argument arg[*dimIndex], if any, of a reduction
+// over an array of rank `rank`. A valid constant DIM= is stored in `dim`.
+// Returns false when the call must not be folded because of DIM= (e.g. a
+// constant DIM= outside [1, rank], which is also diagnosed); an absent DIM=
+// returns true without storing a value in `dim`.
+bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &,
+ ActualArguments &, std::optional<int> dimIndex, int rank);
+
+// Fold and validate a MASK= argument against an ARRAY= of the given shape.
+// Returns null when MASK= is absent, is not constant, or fails the
+// CheckConformance() test against `shape`. A scalar MASK= is accepted and
+// returned as-is, so callers must be prepared for a rank-0 result.
+Constant<LogicalResult> *GetReductionMASK(
+ std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
+ FoldingContext &);
// Common preprocessing for reduction transformational intrinsic function
// folding. If the intrinsic can have DIM= &/or MASK= arguments, extract
@@ -26,10 +32,9 @@ std::optional<ConstantSubscript> CheckDIM(
// the mask. If the result is present, the intrinsic call can be folded.
template <typename T>
static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
- ActualArguments &arg, std::optional<ConstantSubscript> &dim,
- const Scalar<T> &identity, int arrayIndex,
- std::optional<std::size_t> dimIndex = std::nullopt,
- std::optional<std::size_t> maskIndex = std::nullopt) {
+ ActualArguments &arg, std::optional<int> &dim, const Scalar<T> &identity,
+ int arrayIndex, std::optional<int> dimIndex = std::nullopt,
+ std::optional<int> maskIndex = std::nullopt) {
if (arg.empty()) {
return std::nullopt;
}
@@ -37,46 +42,37 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
if (!folded || folded->Rank() < 1) {
return std::nullopt;
}
- if (dimIndex && arg.size() >= *dimIndex + 1 && arg[*dimIndex]) {
- dim = CheckDIM(context, arg[*dimIndex], folded->Rank());
- if (!dim) {
- return std::nullopt;
- }
+ if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) {
+ return std::nullopt;
}
- if (maskIndex && arg.size() >= *maskIndex + 1 && arg[*maskIndex]) {
- if (Constant<LogicalResult> *
- mask{Folder<LogicalResult>{context}.Folding(arg[*maskIndex])}) {
- if (CheckConformance(context.messages(), AsShape(folded->shape()),
- AsShape(mask->shape()),
- CheckConformanceFlags::RightScalarExpandable, "ARRAY=", "MASK=")
- .value_or(false)) {
- // Apply the mask in place to the array
- std::size_t n{folded->size()};
- std::vector<typename Constant<T>::Element> elements;
- if (auto scalarMask{mask->GetScalarValue()}) {
- if (scalarMask->IsTrue()) {
- return Constant<T>{*folded};
- } else { // MASK=.FALSE.
- elements = std::vector<typename Constant<T>::Element>(n, identity);
- }
- } else { // mask is an array; test its elements
+ if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() &&
+ arg[*maskIndex]) {
+ if (const Constant<LogicalResult> *mask{
+ GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) {
+ // Apply the mask in place to the array
+ std::size_t n{folded->size()};
+ std::vector<typename Constant<T>::Element> elements;
+ if (auto scalarMask{mask->GetScalarValue()}) {
+ if (scalarMask->IsTrue()) {
+ return Constant<T>{*folded};
+ } else { // MASK=.FALSE.
elements = std::vector<typename Constant<T>::Element>(n, identity);
- ConstantSubscripts at{folded->lbounds()};
- for (std::size_t j{0}; j < n; ++j, folded->IncrementSubscripts(at)) {
- if (mask->values()[j].IsTrue()) {
- elements[j] = folded->At(at);
- }
- }
}
- if constexpr (T::category == TypeCategory::Character) {
- return Constant<T>{static_cast<ConstantSubscript>(identity.size()),
- std::move(elements), ConstantSubscripts{folded->shape()}};
- } else {
- return Constant<T>{
- std::move(elements), ConstantSubscripts{folded->shape()}};
+ } else { // mask is an array; test its elements
+ elements = std::vector<typename Constant<T>::Element>(n, identity);
+ ConstantSubscripts at{folded->lbounds()};
+ for (std::size_t j{0}; j < n; ++j, folded->IncrementSubscripts(at)) {
+ if (mask->values()[j].IsTrue()) {
+ elements[j] = folded->At(at);
+ }
}
+ }
+ if constexpr (T::category == TypeCategory::Character) {
+ return Constant<T>{static_cast<ConstantSubscript>(identity.size()),
+ std::move(elements), ConstantSubscripts{folded->shape()}};
} else {
- return std::nullopt;
+ return Constant<T>{
+ std::move(elements), ConstantSubscripts{folded->shape()}};
}
} else {
return std::nullopt;
@@ -90,7 +86,7 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
// or to a scalar (w/o DIM=).
template <typename T, typename ACCUMULATOR, typename ARRAY>
static Constant<T> DoReduction(const Constant<ARRAY> &array,
- std::optional<ConstantSubscript> &dim, const Scalar<T> &identity,
+ std::optional<int> &dim, const Scalar<T> &identity,
ACCUMULATOR &accumulator) {
ConstantSubscripts at{array.lbounds()};
std::vector<typename Constant<T>::Element> elements;
@@ -132,7 +128,7 @@ static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
T::category == TypeCategory::Real ||
T::category == TypeCategory::Character);
using Element = Scalar<T>;
- std::optional<ConstantSubscript> dim;
+ std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
@@ -159,7 +155,7 @@ static Expr<T> FoldProduct(
T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex);
using Element = typename Constant<T>::Element;
- std::optional<ConstantSubscript> dim;
+ std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
@@ -192,7 +188,7 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex);
using Element = typename Constant<T>::Element;
- std::optional<ConstantSubscript> dim;
+ std::optional<int> dim;
Element identity{}, correction{};
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
diff --git a/flang/lib/Evaluate/shape.cpp b/flang/lib/Evaluate/shape.cpp
index 0d2bb504a84ad..a387845625f2f 100644
--- a/flang/lib/Evaluate/shape.cpp
+++ b/flang/lib/Evaluate/shape.cpp
@@ -208,7 +208,8 @@ MaybeExtentExpr GetSize(Shape &&shape) {
ConstantSubscript GetSize(const ConstantSubscripts &shape) {
ConstantSubscript size{1};
- for (auto dim : std::move(shape)) {
+ for (auto dim : shape) {
+ CHECK(dim >= 0);
size *= dim;
}
return size;
@@ -652,14 +653,15 @@ auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
}
}
}
- } else if (intrinsic->name == "maxloc" || intrinsic->name == "minloc") {
- // TODO: FINDLOC
- if (call.arguments().size() >= 2) {
+ } else if (intrinsic->name == "findloc" || intrinsic->name == "maxloc" ||
+ intrinsic->name == "minloc") {
+ std::size_t dimIndex{intrinsic->name == "findloc" ? 2u : 1u};
+ if (call.arguments().size() > dimIndex) {
if (auto arrayShape{
(*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}) {
auto rank{static_cast<int>(arrayShape->size())};
if (const auto *dimArg{
- UnwrapExpr<Expr<SomeType>>(call.arguments()[1])}) {
+ UnwrapExpr<Expr<SomeType>>(call.arguments()[dimIndex])}) {
auto dim{ToInt64(*dimArg)};
if (dim && *dim >= 1 && *dim <= rank) {
arrayShape->erase(arrayShape->begin() + (*dim - 1));
diff --git a/flang/test/Evaluate/folding30.f90 b/flang/test/Evaluate/folding30.f90
new file mode 100644
index 0000000000000..748723c08e88d
--- /dev/null
+++ b/flang/test/Evaluate/folding30.f90
@@ -0,0 +1,21 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of FINDLOC
+module m1
+  integer, parameter :: ia1(2:6) = [1, 2, 3, 2, 1]
+  ! ia2 (rows 2:3, columns 2:4):  1 3 2
+  !                               2 3 1
+  integer, parameter :: ia2(2:3,2:4) = reshape([1, 2, 3, 3, 2, 1], shape(ia2))
+
+  logical, parameter :: ti1a = all(findloc(ia1, 1) == 1)
+  logical, parameter :: ti1ar = rank(findloc(ia1, 1)) == 1
+  logical, parameter :: ti1ak = kind(findloc(ia1, 1, kind=2)) == 2
+  logical, parameter :: ti1ad = findloc(ia1, 1, dim=1) == 1
+  logical, parameter :: ti1adr = rank(findloc(ia1, 1, dim=1)) == 0
+  logical, parameter :: ti1b = all(findloc(ia1, 1, back=.true.) == 5)
+  logical, parameter :: ti1c = all(findloc(ia1, 2, mask=[.false., .false., .true., .true., .true.]) == 4)
+  ! Not found: every location is zero
+  logical, parameter :: ti1d = all(findloc(ia1, 7) == 0)
+  logical, parameter :: ti1e = all(findloc(ia1, 2, back=.true.) == 4)
+
+  logical, parameter :: ti2a = all(findloc(ia2, 1) == [1, 1])
+  logical, parameter :: ti2ar = rank(findloc(ia2, 1)) == 1
+  logical, parameter :: ti2b = all(findloc(ia2, 1, back=.true.) == [2, 3])
+  logical, parameter :: ti2c = all(findloc(ia2, 2, mask=reshape([.false., .false., .true., .true., .true., .false.], shape(ia2))) == [1, 3])
+  logical, parameter :: ti2d = all(findloc(ia2, 1, dim=1) == [1, 0, 2])
+  logical, parameter :: ti2e = all(findloc(ia2, 1, dim=2) == [1, 3])
+  ! DIM= combined with BACK= and with an array MASK=
+  logical, parameter :: ti2f = all(findloc(ia2, 3, dim=1) == [0, 1, 0])
+  logical, parameter :: ti2g = all(findloc(ia2, 3, dim=1, back=.true.) == [0, 2, 0])
+  logical, parameter :: ti2h = all(findloc(ia2, 2, dim=2, back=.true.) == [3, 1])
+  logical, parameter :: ti2i = all(findloc(ia2, 3, dim=1, mask=reshape([.true., .false., .false., .true., .true., .true.], shape(ia2))) == [0, 2, 0])
+  logical, parameter :: ti2j = all(findloc(ia2, 9, dim=2) == [0, 0])
+end module
More information about the flang-commits
mailing list