[flang-commits] [flang] cc1d13f - [flang] Fold MAXLOC and MINLOC
peter klausler via flang-commits
flang-commits at lists.llvm.org
Tue Oct 5 11:22:10 PDT 2021
Author: peter klausler
Date: 2021-10-05T11:22:02-07:00
New Revision: cc1d13f997f6db6f2e6c209b9449695b91a68e32
URL: https://github.com/llvm/llvm-project/commit/cc1d13f997f6db6f2e6c209b9449695b91a68e32
DIFF: https://github.com/llvm/llvm-project/commit/cc1d13f997f6db6f2e6c209b9449695b91a68e32.diff
LOG: [flang] Fold MAXLOC and MINLOC
Generalize the code that folds FINDLOC to also handle
folding for MAXLOC and MINLOC.
Differential Revision: https://reviews.llvm.org/D110951
Added:
Modified:
flang/include/flang/Evaluate/type.h
flang/lib/Evaluate/fold-character.cpp
flang/lib/Evaluate/fold-integer.cpp
flang/lib/Evaluate/fold-real.cpp
flang/lib/Evaluate/fold-reduction.h
flang/test/Evaluate/folding30.f90
Removed:
################################################################################
diff --git a/flang/include/flang/Evaluate/type.h b/flang/include/flang/Evaluate/type.h
index a57d8107b7e33..23fc6cab809d8 100644
--- a/flang/include/flang/Evaluate/type.h
+++ b/flang/include/flang/Evaluate/type.h
@@ -335,8 +335,10 @@ using LogicalTypes = CategoryTypes<TypeCategory::Logical>;
using FloatingTypes = common::CombineTuples<RealTypes, ComplexTypes>;
using NumericTypes = common::CombineTuples<IntegerTypes, FloatingTypes>;
-using RelationalTypes = common::CombineTuples<NumericTypes, CharacterTypes>;
-using AllIntrinsicTypes = common::CombineTuples<RelationalTypes, LogicalTypes>;
+using RelationalTypes =
+ common::CombineTuples<IntegerTypes, RealTypes, CharacterTypes>; // no COMPLEX: it is unordered (only == and /= apply)
+using AllIntrinsicTypes =
+ common::CombineTuples<NumericTypes, CharacterTypes, LogicalTypes>;
using LengthlessIntrinsicTypes =
common::CombineTuples<NumericTypes, LogicalTypes>;
diff --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp
index 7c37e4cb011fb..72e124a6a84b2 100644
--- a/flang/lib/Evaluate/fold-character.cpp
+++ b/flang/lib/Evaluate/fold-character.cpp
@@ -102,7 +102,7 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
CharacterUtils<KIND>::TRIM(std::get<Scalar<T>>(*scalar))}};
}
}
- // TODO: maxloc, minloc, transfer
+ // TODO: transfer
return Expr<T>{std::move(funcRef)};
}
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 441a5a9924f98..86fb46ba6094e 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -195,42 +195,58 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
return Expr<T>{std::move(ref)};
}
-// FINDLOC()
-class FindlocHelper {
+// FINDLOC(), MAXLOC(), & MINLOC()
+enum class WhichLocation { Findloc, Maxloc, Minloc };
+template <WhichLocation WHICH> class LocationHelper {
public:
- FindlocHelper(
+ LocationHelper(
DynamicType &&type, ActualArguments &arg, FoldingContext &context)
: type_{type}, arg_{arg}, context_{context} {}
using Result = std::optional<Constant<SubscriptInteger>>;
- using Types = AllIntrinsicTypes;
+ using Types = std::conditional_t<WHICH == WhichLocation::Findloc,
+ AllIntrinsicTypes, RelationalTypes>;
template <typename T> Result Test() const {
if (T::category != type_.category() || T::kind != type_.kind()) {
return std::nullopt;
}
- CHECK(arg_.size() == 6);
+ CHECK(arg_.size() == (WHICH == WhichLocation::Findloc ? 6 : 5));
Folder<T> folder{context_};
Constant<T> *array{folder.Folding(arg_[0])};
- Constant<T> *value{folder.Folding(arg_[1])};
- if (!array || !value) {
+ if (!array) {
return std::nullopt;
}
+ std::optional<Constant<T>> value;
+ if constexpr (WHICH == WhichLocation::Findloc) {
+ if (const Constant<T> *p{folder.Folding(arg_[1])}) {
+ value.emplace(*p);
+ } else {
+ return std::nullopt;
+ }
+ }
std::optional<int> dim;
Constant<LogicalResult> *mask{
- GetReductionMASK(arg_[3], array->shape(), context_)};
- if ((!mask && arg_[3]) ||
- !CheckReductionDIM(dim, context_, arg_, 2, array->Rank())) {
+ GetReductionMASK(arg_[maskArg], array->shape(), context_)};
+ if ((!mask && arg_[maskArg]) ||
+ !CheckReductionDIM(dim, context_, arg_, dimArg, array->Rank())) {
return std::nullopt;
}
bool back{false};
- if (arg_[5]) {
- const auto *backConst{Folder<LogicalResult>{context_}.Folding(arg_[5])};
+ if (arg_[backArg]) {
+ const auto *backConst{
+ Folder<LogicalResult>{context_}.Folding(arg_[backArg])};
if (backConst) {
back = backConst->GetScalarValue().value().IsTrue();
} else {
return std::nullopt;
}
}
+ const RelationalOperator relation{WHICH == WhichLocation::Findloc
+ ? RelationalOperator::EQ
+ : WHICH == WhichLocation::Maxloc
+ ? (back ? RelationalOperator::GE : RelationalOperator::GT)
+ : back ? RelationalOperator::LE
+ : RelationalOperator::LT};
// Use lower bounds of 1 exclusively.
array->SetLowerBoundsToOne();
ConstantSubscripts at{array->lbounds()}, maskAt, resultIndices, resultShape;
@@ -252,10 +268,15 @@ class FindlocHelper {
ConstantSubscript n{GetSize(resultShape)};
for (ConstantSubscript j{0}; j < n; ++j) {
ConstantSubscript hit{array->lbounds()[zbDim] - 1};
+ if constexpr (WHICH != WhichLocation::Findloc) {
+ // Restart the running extreme for each slice; FINDLOC's VALUE= must persist.
+ value.reset();
+ }
for (ConstantSubscript k{0}; k < dimLength;
++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
if ((!mask || mask->At(maskAt).IsTrue()) &&
- IsHit(array->At(at), *value)) {
+ IsHit(array->At(at), value, relation)) {
hit = at[zbDim];
- if (!back) {
+ // MAXLOC/MINLOC must scan the whole slice; only FINDLOC may stop early.
+ if (WHICH == WhichLocation::Findloc && !back) {
break;
@@ -279,7 +300,7 @@ class FindlocHelper {
for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
mask && mask->IncrementSubscripts(maskAt)) {
if ((!mask || mask->At(maskAt).IsTrue()) &&
- IsHit(array->At(at), *value)) {
+ IsHit(array->At(at), value, relation)) {
resultIndices = at;
- if (!back) {
+ if (WHICH == WhichLocation::Findloc && !back) {
break;
@@ -297,43 +318,66 @@ class FindlocHelper {
private:
+ // Tests whether ELEMENT is a hit under RELATION. For FINDLOC, VALUE holds
+ // the fixed VALUE= argument and is never modified. For MAXLOC/MINLOC, VALUE
+ // holds the running extreme of the current scan (empty before the first
+ // unmasked element) and is replaced by ELEMENT on every hit.
template <typename T>
- bool IsHit(typename Constant<T>::Element element, Constant<T> value) const {
+ bool IsHit(typename Constant<T>::Element element,
+ std::optional<Constant<T>> &value,
+ [[maybe_unused]] RelationalOperator relation) const {
std::optional<Expr<LogicalResult>> cmp;
- if constexpr (T::category == TypeCategory::Logical) {
- // array(at) .EQV. value?
- cmp.emplace(
- ConvertToType<LogicalResult>(Expr<T>{LogicalOperation<T::kind>{
- LogicalOperator::Eqv, Expr<T>{Constant<T>{std::move(element)}},
- Expr<T>{std::move(value)}}}));
- } else { // array(at) .EQ. value?
- cmp.emplace(PackageRelation(RelationalOperator::EQ,
- Expr<T>{Constant<T>{std::move(element)}}, Expr<T>{std::move(value)}));
- }
- Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
- return GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
+ bool result{true}; // the first unmasked MAXLOC/MINLOC element always hits
+ if (value) {
+ if constexpr (T::category == TypeCategory::Logical) {
+ // array(at) .EQV. value?
+ static_assert(WHICH == WhichLocation::Findloc);
+ cmp.emplace(
+ ConvertToType<LogicalResult>(Expr<T>{LogicalOperation<T::kind>{
+ LogicalOperator::Eqv, Expr<T>{Constant<T>{element}},
+ Expr<T>{Constant<T>{*value}}}}));
+ } else { // compare array(at) to value
+ cmp.emplace(PackageRelation(relation, Expr<T>{Constant<T>{element}},
+ Expr<T>{Constant<T>{*value}}));
+ }
+ Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
+ result = GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
+ }
+ if constexpr (WHICH != WhichLocation::Findloc) {
+ if (result) {
+ // Keep the new extreme so later elements are compared against it,
+ // not against the first element of the scan.
+ value.emplace(std::move(element));
+ }
+ }
+ return result;
}
+ static constexpr int dimArg{WHICH == WhichLocation::Findloc ? 2 : 1};
+ static constexpr int maskArg{dimArg + 1};
+ static constexpr int backArg{maskArg + 2};
+
DynamicType type_;
ActualArguments &arg_;
FoldingContext &context_;
};
-static std::optional<Constant<SubscriptInteger>> FoldFindlocCall(
+template <WhichLocation which>
+static std::optional<Constant<SubscriptInteger>> FoldLocationCall(
ActualArguments &arg, FoldingContext &context) {
- CHECK(arg.size() == 6);
if (arg[0]) {
if (auto type{arg[0]->GetType()}) {
- return common::SearchTypes(FindlocHelper{std::move(*type), arg, context});
+ return common::SearchTypes(
+ LocationHelper<which>{std::move(*type), arg, context});
}
}
return std::nullopt;
}
-template <typename T>
-static Expr<T> FoldFindloc(FoldingContext &context, FunctionRef<T> &&ref) {
+template <WhichLocation which, typename T>
+static Expr<T> FoldLocation(FoldingContext &context, FunctionRef<T> &&ref) {
static_assert(T::category == TypeCategory::Integer);
if (std::optional<Constant<SubscriptInteger>> found{
- FoldFindlocCall(ref.arguments(), context)}) {
+ FoldLocationCall<which>(ref.arguments(), context)}) {
return Expr<T>{Fold(
context, ConvertToType<T>(Expr<SubscriptInteger>{std::move(*found)}))};
} else {
@@ -451,7 +482,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
DIE("exponent argument must be real");
}
} else if (name == "findloc") {
- return FoldFindloc<T>(context, std::move(funcRef));
+ return FoldLocation<WhichLocation::Findloc, T>(context, std::move(funcRef));
} else if (name == "huge") {
return Expr<T>{Scalar<T>::HUGE()};
} else if (name == "iachar" || name == "ichar") {
@@ -661,6 +692,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
},
sx->u);
}
+ } else if (name == "maxloc") {
+ return FoldLocation<WhichLocation::Maxloc, T>(context, std::move(funcRef));
} else if (name == "maxval") {
return FoldMaxvalMinval<T>(context, std::move(funcRef),
RelationalOperator::GT, T::Scalar::Least());
@@ -669,6 +702,10 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else if (name == "merge_bits") {
return FoldElementalIntrinsic<T, T, T, T>(
context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
+ } else if (name == "min") {
+ return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
+ } else if (name == "min0" || name == "min1") {
+ return RewriteSpecificMINorMAX(context, std::move(funcRef));
} else if (name == "minexponent") {
if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
return std::visit(
@@ -678,10 +715,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
},
sx->u);
}
- } else if (name == "min") {
- return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
- } else if (name == "min0" || name == "min1") {
- return RewriteSpecificMINorMAX(context, std::move(funcRef));
+ } else if (name == "minloc") {
+ return FoldLocation<WhichLocation::Minloc, T>(context, std::move(funcRef));
} else if (name == "minval") {
return FoldMaxvalMinval<T>(
context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
@@ -853,8 +888,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else if (name == "ubound") {
return UBOUND(context, std::move(funcRef));
}
- // TODO: dot_product, ibits, image_status, ishftc,
- // matmul, maxloc, minloc, sign, transfer
+ // TODO: dot_product, ibits, ishftc, matmul, sign, transfer
return Expr<T>{std::move(funcRef)};
}
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index bffd8ea757555..93b3ce20a75b5 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -143,7 +143,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return Expr<T>{Scalar<T>::TINY()};
}
// TODO: dim, dot_product, fraction, matmul,
- // maxloc, minloc, modulo, nearest, norm2, rrspacing, scale,
+ // modulo, nearest, norm2, rrspacing, scale,
// set_exponent, spacing, transfer,
// bessel_jn (transformational) and bessel_yn (transformational)
return Expr<T>{std::move(funcRef)};
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index bfa4a1f80c79b..c7065c2c6bf62 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-// TODO: DOT_PRODUCT, NORM2, MAXLOC, MINLOC, PARITY
+// TODO: DOT_PRODUCT, NORM2, PARITY
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
diff --git a/flang/test/Evaluate/folding30.f90 b/flang/test/Evaluate/folding30.f90
index 748723c08e88d..d0bbfde5ee480 100644
--- a/flang/test/Evaluate/folding30.f90
+++ b/flang/test/Evaluate/folding30.f90
@@ -1,21 +1,58 @@
! RUN: %python %S/test_folding.py %s %flang_fc1
-! Tests folding of FINDLOC
+! Tests folding of FINDLOC, MAXLOC, & MINLOC
+! test_folding.py only checks logical parameters whose names begin with "test_".
module m1
integer, parameter :: ia1(2:6) = [1, 2, 3, 2, 1]
integer, parameter :: ia2(2:3,2:4) = reshape([1, 2, 3, 3, 2, 1], shape(ia2))
- logical, parameter :: ti1a = all(findloc(ia1, 1) == 1)
- logical, parameter :: ti1ar = rank(findloc(ia1, 1)) == 1
- logical, parameter :: ti1ak = kind(findloc(ia1, 1, kind=2)) == 2
- logical, parameter :: ti1ad = findloc(ia1, 1, dim=1) == 1
- logical, parameter :: ti1adr = rank(findloc(ia1, 1, dim=1)) == 0
- logical, parameter :: ti1b = all(findloc(ia1, 1, back=.true.) == 5)
- logical, parameter :: ti1c = all(findloc(ia1, 2, mask=[.false., .false., .true., .true., .true.]) == 4)
+ logical, parameter :: test_fi1a = all(findloc(ia1, 1) == 1)
+ logical, parameter :: test_fi1ar = rank(findloc(ia1, 1)) == 1
+ logical, parameter :: test_fi1ak = kind(findloc(ia1, 1, kind=2)) == 2
+ logical, parameter :: test_fi1ad = findloc(ia1, 1, dim=1) == 1
+ logical, parameter :: test_fi1adr = rank(findloc(ia1, 1, dim=1)) == 0
+ logical, parameter :: test_fi1b = all(findloc(ia1, 1, back=.true.) == 5)
+ logical, parameter :: test_fi1c = all(findloc(ia1, 2, mask=[.false., .false., .true., .true., .true.]) == 4)
- logical, parameter :: ti2a = all(findloc(ia2, 1) == [1, 1])
- logical, parameter :: ti2ar = rank(findloc(ia2, 1)) == 1
- logical, parameter :: ti2b = all(findloc(ia2, 1, back=.true.) == [2, 3])
- logical, parameter :: ti2c = all(findloc(ia2, 2, mask=reshape([.false., .false., .true., .true., .true., .false.], shape(ia2))) == [1, 3])
- logical, parameter :: ti2d = all(findloc(ia2, 1, dim=1) == [1, 0, 2])
- logical, parameter :: ti2e = all(findloc(ia2, 1, dim=2) == [1, 3])
+ logical, parameter :: test_fi2a = all(findloc(ia2, 1) == [1, 1])
+ logical, parameter :: test_fi2ar = rank(findloc(ia2, 1)) == 1
+ logical, parameter :: test_fi2b = all(findloc(ia2, 1, back=.true.) == [2, 3])
+ logical, parameter :: test_fi2c = all(findloc(ia2, 2, mask=reshape([.false., .false., .true., .true., .true., .false.], shape(ia2))) == [1, 3])
+ logical, parameter :: test_fi2d = all(findloc(ia2, 1, dim=1) == [1, 0, 2])
+ logical, parameter :: test_fi2e = all(findloc(ia2, 1, dim=2) == [1, 3])
+
+ logical, parameter :: test_xi1a = all(maxloc(ia1) == 3)
+ logical, parameter :: test_xi1ar = rank(maxloc(ia1)) == 1
+ logical, parameter :: test_xi1ak = kind(maxloc(ia1, kind=2)) == 2
+ logical, parameter :: test_xi1ad = maxloc(ia1, dim=1) == 3
+ logical, parameter :: test_xi1adr = rank(maxloc(ia1, dim=1)) == 0
+ logical, parameter :: test_xi1b = all(maxloc(ia1, back=.true.) == 3)
+ logical, parameter :: test_xi1c = all(maxloc(ia1, mask=[.false., .true., .false., .true., .true.]) == 2)
+ logical, parameter :: test_xi1d = all(maxloc(ia1, mask=[.false., .true., .false., .true., .true.], back=.true.) == 4)
+
+ logical, parameter :: test_xi2a = all(maxloc(ia2) == [1, 2])
+ logical, parameter :: test_xi2ar = rank(maxloc(ia2)) == 1
+ logical, parameter :: test_xi2b = all(maxloc(ia2, back=.true.) == [2, 2])
+ logical, parameter :: test_xi2c = all(maxloc(ia2, mask=reshape([.false., .true., .false., .false., .true., .true.], shape(ia2))) == [2, 1])
+ logical, parameter :: test_xi2d = all(maxloc(ia2, mask=reshape([.false., .true., .false., .false., .true., .true.], shape(ia2)), back=.true.) == [1, 3])
+ logical, parameter :: test_xi2e = all(maxloc(ia2, dim=1) == [2, 1, 1])
+ logical, parameter :: test_xi2f = all(maxloc(ia2, dim=1, back=.true.) == [2, 2, 1])
+ logical, parameter :: test_xi2g = all(maxloc(ia2, dim=2) == [2, 2])
+
+ logical, parameter :: test_ni1a = all(minloc(ia1) == 1)
+ logical, parameter :: test_ni1ar = rank(minloc(ia1)) == 1
+ logical, parameter :: test_ni1ak = kind(minloc(ia1, kind=2)) == 2
+ logical, parameter :: test_ni1ad = minloc(ia1, dim=1) == 1
+ logical, parameter :: test_ni1adr = rank(minloc(ia1, dim=1)) == 0
+ logical, parameter :: test_ni1b = all(minloc(ia1, back=.true.) == 5)
+ logical, parameter :: test_ni1c = all(minloc(ia1, mask=[.false., .true., .true., .true., .false.]) == 2)
+ logical, parameter :: test_ni1d = all(minloc(ia1, mask=[.false., .true., .true., .true., .false.], back=.true.) == 4)
+
+ logical, parameter :: test_ni2a = all(minloc(ia2) == [1, 1])
+ logical, parameter :: test_ni2ar = rank(minloc(ia2)) == 1
+ logical, parameter :: test_ni2b = all(minloc(ia2, back=.true.) == [2, 3])
+ logical, parameter :: test_ni2c = all(minloc(ia2, mask=reshape([.false., .true., .true., .false., .true., .false.], shape(ia2))) == [2, 1])
+ logical, parameter :: test_ni2d = all(minloc(ia2, mask=reshape([.false., .true., .true., .false., .true., .false.], shape(ia2)), back=.true.) == [1, 3])
+ logical, parameter :: test_ni2e = all(minloc(ia2, dim=1) == [1, 1, 2])
+ logical, parameter :: test_ni2f = all(minloc(ia2, dim=1, back=.true.) == [1, 2, 2])
+ logical, parameter :: test_ni2g = all(minloc(ia2, dim=2) == [1, 3])
end module
More information about the flang-commits
mailing list