[flang-commits] [flang] 26aff84 - [flang] Fold COUNT()
peter klausler via flang-commits
flang-commits at lists.llvm.org
Thu Sep 16 17:09:31 PDT 2021
Author: peter klausler
Date: 2021-09-16T17:09:23-07:00
New Revision: 26aff847d8860c14bc3e829e4bfe7980058504c0
URL: https://github.com/llvm/llvm-project/commit/26aff847d8860c14bc3e829e4bfe7980058504c0
DIFF: https://github.com/llvm/llvm-project/commit/26aff847d8860c14bc3e829e4bfe7980058504c0.diff
LOG: [flang] Fold COUNT()
Complete folding of the intrinsic reduction function COUNT() for all
cases, including partial reductions with DIM= arguments.
Differential Revision: https://reviews.llvm.org/D109911
Added:
flang/lib/Evaluate/fold-reduction.cpp
flang/test/Evaluate/folding29.f90
Modified:
flang/lib/Evaluate/CMakeLists.txt
flang/lib/Evaluate/fold-implementation.h
flang/lib/Evaluate/fold-integer.cpp
flang/lib/Evaluate/fold-logical.cpp
flang/lib/Evaluate/fold-reduction.h
Removed:
################################################################################
diff --git a/flang/lib/Evaluate/CMakeLists.txt b/flang/lib/Evaluate/CMakeLists.txt
index a2fdc10896b43..2b8eafafd333d 100644
--- a/flang/lib/Evaluate/CMakeLists.txt
+++ b/flang/lib/Evaluate/CMakeLists.txt
@@ -24,6 +24,7 @@ add_flang_library(FortranEvaluate
fold-integer.cpp
fold-logical.cpp
fold-real.cpp
+ fold-reduction.cpp
formatting.cpp
host.cpp
initial-image.cpp
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index f68e2ea0acd4d..c61637263a630 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -492,7 +492,7 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
// Build and return constant result
if constexpr (TR::category == TypeCategory::Character) {
auto len{static_cast<ConstantSubscript>(
- results.size() ? results[0].length() : 0)};
+ results.empty() ? 0 : results[0].length())};
return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
} else {
return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
@@ -944,7 +944,7 @@ Expr<T> FoldMINorMAX(
if (constantArgs.size() != funcRef.arguments().size()) {
return Expr<T>(std::move(funcRef));
}
- CHECK(constantArgs.size() > 0);
+ CHECK(!constantArgs.empty());
Expr<T> result{std::move(*constantArgs[0])};
for (std::size_t i{1}; i < constantArgs.size(); ++i) {
Extremum<T> extremum{order, result, Expr<T>{std::move(*constantArgs[i])}};
@@ -1075,7 +1075,7 @@ template <typename T> class ArrayConstructorFolder {
Expr<T> folded{Fold(context_, common::Clone(expr.value()))};
if (const auto *c{UnwrapConstantValue<T>(folded)}) {
// Copy elements in Fortran array element order
- if (c->size() > 0) {
+ if (!c->empty()) {
ConstantSubscripts index{c->lbounds()};
do {
elements_.emplace_back(c->At(index));
@@ -1156,7 +1156,7 @@ template <typename T>
std::optional<Expr<T>> AsFlatArrayConstructor(const Expr<T> &expr) {
if (const auto *c{UnwrapConstantValue<T>(expr)}) {
ArrayConstructor<T> result{expr};
- if (c->size() > 0) {
+ if (!c->empty()) {
ConstantSubscripts at{c->lbounds()};
do {
result.Push(Expr<T>{Constant<T>{c->At(at)}});
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 032e0da273d18..3fdf252407e92 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -174,21 +174,47 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
return Expr<T>{std::move(funcRef)};
}
+// COUNT(MASK [,DIM]): number of .TRUE. mask elements, whole-array or per DIM=
+template <typename T>
+static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
+ static_assert(T::category == TypeCategory::Integer);
+ ActualArguments &arg{ref.arguments()};
+ if (const Constant<LogicalResult> *mask{arg.empty()
+ ? nullptr
+ : Folder<LogicalResult>{context}.Folding(arg[0])}) {
+ std::optional<ConstantSubscript> dim; // stays empty without DIM= -> scalar result
+ if (arg.size() > 1 && arg[1]) {
+ dim = CheckDIM(context, arg[1], mask->Rank());
+ if (!dim) {
+ mask = nullptr; // DIM= not a valid constant: leave the call unfolded
+ }
+ }
+ if (mask) {
+ auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
+ if (mask->At(at).IsTrue()) {
+ element = element.AddSigned(Scalar<T>{1}).value; // NOTE(review): only .value kept; any overflow report from AddSigned is dropped
+ }
+ }};
+ return Expr<T>{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)}; // Scalar<T>{} is the starting count (presumably zero)
+ }
+ }
+ return Expr<T>{std::move(ref)};
+}
+
// for IALL, IANY, & IPARITY
template <typename T>
static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
Scalar<T> identity) {
static_assert(T::category == TypeCategory::Integer);
- using Element = Scalar<T>;
std::optional<ConstantSubscript> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
- auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
+ auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
element = (element.*operation)(array->At(at));
}};
- return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+ return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}; // explicit <T>: DoReduction's result type is now separate from its Constant<ARRAY> input
}
return Expr<T>{std::move(ref)};
}
@@ -237,17 +263,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
cx->u);
}
} else if (name == "count") {
- if (!args[1]) { // TODO: COUNT(x,DIM=d)
- if (const auto *constant{UnwrapConstantValue<LogicalResult>(args[0])}) {
- std::int64_t result{0};
- for (const auto &element : constant->values()) {
- if (element.IsTrue()) {
- ++result;
- }
- }
- return Expr<T>{result};
- }
- }
+ return FoldCount<T>(context, std::move(funcRef));
} else if (name == "digits") {
if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
return Expr<T>{std::visit(
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index 27a2f0c36b0f0..71a8f701991bd 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -26,7 +26,7 @@ static Expr<T> FoldAllAny(FoldingContext &context, FunctionRef<T> &&ref,
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
element = (element.*operation)(array->At(at));
}};
- return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+ return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};
}
diff --git a/flang/lib/Evaluate/fold-reduction.cpp b/flang/lib/Evaluate/fold-reduction.cpp
new file mode 100644
index 0000000000000..f171f859dc064
--- /dev/null
+++ b/flang/lib/Evaluate/fold-reduction.cpp
@@ -0,0 +1,32 @@
+//===-- lib/Evaluate/fold-reduction.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "fold-reduction.h"
+
+namespace Fortran::evaluate {
+
+std::optional<ConstantSubscript> CheckDIM( // fold DIM= and validate it against rank
+ FoldingContext &context, std::optional<ActualArgument> &arg, int rank) {
+ if (arg) {
+ if (auto *dimConst{Folder<SubscriptInteger>{context}.Folding(arg)}) { // DIM= folded to a constant
+ if (auto dimScalar{dimConst->GetScalarValue()}) { // ... and that constant is scalar
+ auto dim{dimScalar->ToInt64()};
+ if (dim >= 1 && dim <= rank) {
+ return {dim}; // valid dimension number in [1, rank]
+ } else {
+ context.messages().Say(
+ "DIM=%jd is not valid for an array of rank %d"_err_en_US,
+ static_cast<std::intmax_t>(dim), rank);
+ }
+ }
+ }
+ }
+ return std::nullopt; // absent, not folded to a constant, non-scalar, or out of range (error reported only for the last)
+}
+
+} // namespace Fortran::evaluate
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 4b265ecf4716a..714de7c2b51d5 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -6,8 +6,7 @@
//
//===----------------------------------------------------------------------===//
-// TODO: ALL, ANY, COUNT, DOT_PRODUCT, FINDLOC, IALL, IANY, IPARITY,
-// NORM2, MAXLOC, MINLOC, PARITY, PRODUCT, SUM
+// TODO: DOT_PRODUCT, FINDLOC, NORM2, MAXLOC, MINLOC, PARITY
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
@@ -16,6 +15,10 @@
namespace Fortran::evaluate {
+// Folds & validates a DIM= actual argument.
+std::optional<ConstantSubscript> CheckDIM(
+ FoldingContext &, std::optional<ActualArgument> &, int rank);
+
// Common preprocessing for reduction transformational intrinsic function
// folding. If the intrinsic can have DIM= &/or MASK= arguments, extract
// and check them. If a MASK= is present, apply it to the array data and
@@ -35,18 +38,7 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
return std::nullopt;
}
if (dimIndex && arg.size() >= *dimIndex + 1 && arg[*dimIndex]) {
- if (auto *dimConst{
- Folder<SubscriptInteger>{context}.Folding(arg[*dimIndex])}) {
- if (auto dimScalar{dimConst->GetScalarValue()}) {
- dim.emplace(dimScalar->ToInt64());
- if (*dim < 1 || *dim > folded->Rank()) {
- context.messages().Say(
- "DIM=%jd is not valid for an array of rank %d"_err_en_US,
- static_cast<std::intmax_t>(*dim), folded->Rank());
- dim.reset();
- }
- }
- }
+ dim = CheckDIM(context, arg[*dimIndex], folded->Rank());
if (!dim) {
return std::nullopt;
}
@@ -96,8 +88,8 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
// Generalized reduction to an array of one dimension fewer (w/ DIM=)
// or to a scalar (w/o DIM=).
-template <typename T, typename ACCUMULATOR>
-static Constant<T> DoReduction(const Constant<T> &array,
+template <typename T, typename ACCUMULATOR, typename ARRAY>
+static Constant<T> DoReduction(const Constant<ARRAY> &array,
std::optional<ConstantSubscript> &dim, const Scalar<T> &identity,
ACCUMULATOR &accumulator) {
ConstantSubscripts at{array.lbounds()};
@@ -154,7 +146,7 @@ static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
element = array->At(at);
}
}};
- return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+ return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};
}
@@ -187,7 +179,7 @@ static Expr<T> FoldProduct(
context.messages().Say(
"PRODUCT() of %s data overflowed"_en_US, T::AsFortran());
} else {
- return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+ return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
}
return Expr<T>{std::move(ref)};
@@ -226,7 +218,7 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
context.messages().Say(
"SUM() of %s data overflowed"_en_US, T::AsFortran());
} else {
- return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+ return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
}
return Expr<T>{std::move(ref)};
diff --git a/flang/test/Evaluate/folding29.f90 b/flang/test/Evaluate/folding29.f90
new file mode 100644
index 0000000000000..c0ab0631a0c42
--- /dev/null
+++ b/flang/test/Evaluate/folding29.f90
@@ -0,0 +1,11 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of COUNT()
+module m
+ logical, parameter :: arr(3,4) = reshape([(modulo(j, 2) == 1, j = 1, size(arr))], shape(arr)) ! .TRUE. for odd j, filled column-major
+ logical, parameter :: test_1 = count([1, 2, 3, 2, 1] < [(j, j=1, 5)]) == 2 ! only elements 4 and 5 compare true
+ logical, parameter :: test_2 = count(arr) == 6
+ logical, parameter :: test_3 = all(count(arr, dim=1) == [2, 1, 2, 1]) ! per column: j=1-3, 4-6, 7-9, 10-12
+ logical, parameter :: test_4 = all(count(arr, dim=2) == [2, 2, 2]) ! per row: each row alternates T/F across columns
+ logical, parameter :: test_5 = count(logical(arr, kind=1)) == 6 ! non-default LOGICAL kind
+ logical, parameter :: test_6 = count(logical(arr, kind=2)) == 6 ! non-default LOGICAL kind
+end module
More information about the flang-commits mailing list