[flang-commits] [flang] [flang][runtime] Treatment of NaN in MAXVAL/MAXLOC/MINVAL/MINLOC (PR #76999)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Thu Jan 4 12:39:24 PST 2024


https://github.com/klausler created https://github.com/llvm/llvm-project/pull/76999

Detect NaN elements in data and handle them as gfortran does at runtime: a NaN result is produced only when all of the (unmasked) data are NaNs; otherwise any non-NaN value is preferred.  Ensure that folding returns the same results as runtime computation.

Fixes llvm-test-suite/Fortran/gfortran/regression/maxloc_2.f90 (and probably others).

>From 24f53e5042e7dfe20a1d038a220dedf7effeb5f8 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Thu, 4 Jan 2024 10:26:00 -0800
Subject: [PATCH] [flang][runtime] Treatment of NaN in
 MAXVAL/MAXLOC/MINVAL/MINLOC

Detect NaN elements in data and handle them as gfortran does at
runtime: a NaN result is produced only when all of the (unmasked)
data are NaNs; otherwise any non-NaN value is preferred.  Ensure
that folding returns the same results as runtime computation.

Fixes llvm-test-suite/Fortran/gfortran/regression/maxloc_2.f90
(and probably others).
---
 flang/docs/Extensions.md              |  4 ++
 flang/lib/Evaluate/fold-character.cpp |  8 ++--
 flang/lib/Evaluate/fold-integer.cpp   | 37 ++++++++++------
 flang/lib/Evaluate/fold-logical.cpp   |  6 +--
 flang/lib/Evaluate/fold-real.cpp      | 16 ++++---
 flang/lib/Evaluate/fold-reduction.h   | 62 +++++++++++++++++++--------
 flang/runtime/extrema.cpp             | 14 +++++-
 flang/test/Evaluate/fold-findloc.f90  | 12 ++++++
 flang/test/Evaluate/folding20.f90     | 10 +++++
 9 files changed, 122 insertions(+), 47 deletions(-)

diff --git a/flang/docs/Extensions.md b/flang/docs/Extensions.md
index 16eb67f2e27c81..94f4b31ec3822c 100644
--- a/flang/docs/Extensions.md
+++ b/flang/docs/Extensions.md
@@ -654,6 +654,10 @@ end
   we don't round.  This seems to be how the Intel Fortran compilers
   behave.
 
+* For real `MAXVAL`, `MINVAL`, `MAXLOC`, and `MINLOC`, NaN values are
+  ignored unless there is at least one unmasked array element and *all*
+  of the unmasked elements are NaNs.
+
 ## De Facto Standard Features
 
 * `EXTENDS_TYPE_OF()` returns `.TRUE.` if both of its arguments have the
diff --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp
index a599815fa7aee9..6633e0b97f0d45 100644
--- a/flang/lib/Evaluate/fold-character.cpp
+++ b/flang/lib/Evaluate/fold-character.cpp
@@ -77,8 +77,8 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
     SingleCharType least{0};
     if (auto identity{Identity<T>(
             StringType{least}, GetConstantLength(context, funcRef, 0))}) {
-      return FoldMaxvalMinval<T>(
-          context, std::move(funcRef), RelationalOperator::GT, *identity);
+      return FoldMaxvalMinval<T>(context, std::move(funcRef),
+          RelationalOperator::GT, *identity, *identity);
     }
   } else if (name == "min") {
     return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
@@ -87,8 +87,8 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
     SingleCharType most{0x7fffffff >> (8 * (4 - KIND))};
     if (auto identity{Identity<T>(
             StringType{most}, GetConstantLength(context, funcRef, 0))}) {
-      return FoldMaxvalMinval<T>(
-          context, std::move(funcRef), RelationalOperator::LT, *identity);
+      return FoldMaxvalMinval<T>(context, std::move(funcRef),
+          RelationalOperator::LT, *identity, *identity);
     }
   } else if (name == "new_line") {
     return Expr<T>{Constant<T>{CharacterUtils<KIND>::NEW_LINE()}};
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index ba4bc6a04750ff..3169503837b9c3 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -270,7 +270,8 @@ template <typename T, int MASK_KIND> class CountAccumulator {
 
 public:
   CountAccumulator(const Constant<MaskT> &mask) : mask_{mask} {}
-  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+  void operator()(
+      Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
     if (mask_.At(at).IsTrue()) {
       auto incremented{element.AddSigned(Scalar<T>{1})};
       overflow_ |= incremented.overflow;
@@ -395,7 +396,7 @@ template <WhichLocation WHICH> class LocationHelper {
         for (ConstantSubscript k{0}; k < dimLength;
              ++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
           if ((!mask || mask->At(maskAt).IsTrue()) &&
-              IsHit(array->At(at), value, relation)) {
+              IsHit(array->At(at), value, relation, back)) {
             hit = at[zbDim];
             if constexpr (WHICH == WhichLocation::Findloc) {
               if (!back) {
@@ -422,7 +423,7 @@ template <WhichLocation WHICH> class LocationHelper {
       for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
            mask && mask->IncrementSubscripts(maskAt)) {
         if ((!mask || mask->At(maskAt).IsTrue()) &&
-            IsHit(array->At(at), value, relation)) {
+            IsHit(array->At(at), value, relation, back)) {
           resultIndices = at;
           if constexpr (WHICH == WhichLocation::Findloc) {
             if (!back) {
@@ -444,7 +445,8 @@ template <WhichLocation WHICH> class LocationHelper {
   template <typename T>
   bool IsHit(typename Constant<T>::Element element,
       std::optional<Constant<T>> &value,
-      [[maybe_unused]] RelationalOperator relation) const {
+      [[maybe_unused]] RelationalOperator relation,
+      [[maybe_unused]] bool back) const {
     std::optional<Expr<LogicalResult>> cmp;
     bool result{true};
     if (value) {
@@ -455,8 +457,19 @@ template <WhichLocation WHICH> class LocationHelper {
             Expr<T>{LogicalOperation<T::kind>{LogicalOperator::Eqv,
                 Expr<T>{Constant<T>{element}}, Expr<T>{Constant<T>{*value}}}}));
       } else { // compare array(at) to value
-        cmp.emplace(PackageRelation(relation, Expr<T>{Constant<T>{element}},
-            Expr<T>{Constant<T>{*value}}));
+        if constexpr (T::category == TypeCategory::Real &&
+            (WHICH == WhichLocation::Maxloc ||
+                WHICH == WhichLocation::Minloc)) {
+          if (value && value->GetScalarValue().value().IsNotANumber() &&
+              (back || !element.IsNotANumber())) {
+            // Replace NaN
+            cmp.emplace(Constant<LogicalResult>{Scalar<LogicalResult>{true}});
+          }
+        }
+        if (!cmp) {
+          cmp.emplace(PackageRelation(relation, Expr<T>{Constant<T>{element}},
+              Expr<T>{Constant<T>{*value}}));
+        }
       }
       Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
       result = GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
@@ -523,9 +536,9 @@ static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
     Scalar<T> identity) {
   static_assert(T::category == TypeCategory::Integer);
   std::optional<int> dim;
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
-              /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
+  if (std::optional<Constant<T>> array{ProcessReductionArgs<T>(context,
+          ref.arguments(), dim, identity, identity,
+          /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
     OperationAccumulator<T> accumulator{*array, operation};
     return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
   }
@@ -1070,7 +1083,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
     return FoldLocation<WhichLocation::Maxloc, T>(context, std::move(funcRef));
   } else if (name == "maxval") {
     return FoldMaxvalMinval<T>(context, std::move(funcRef),
-        RelationalOperator::GT, T::Scalar::Least());
+        RelationalOperator::GT, T::Scalar::Least(), T::Scalar::Least());
   } else if (name == "merge_bits") {
     return FoldElementalIntrinsic<T, T, T, T>(
         context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
@@ -1090,8 +1103,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
   } else if (name == "minloc") {
     return FoldLocation<WhichLocation::Minloc, T>(context, std::move(funcRef));
   } else if (name == "minval") {
-    return FoldMaxvalMinval<T>(
-        context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
+    return FoldMaxvalMinval<T>(context, std::move(funcRef),
+        RelationalOperator::LT, T::Scalar::HUGE(), T::Scalar::HUGE());
   } else if (name == "mod") {
     return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
         ScalarFuncWithContext<T, T, T>(
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index 82a5cb20db9e40..1d76f39a9b252a 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -31,9 +31,9 @@ static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
     Scalar<T> identity) {
   static_assert(T::category == TypeCategory::Logical);
   std::optional<int> dim;
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
-              /*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
+  if (std::optional<Constant<T>> array{ProcessReductionArgs<T>(context,
+          ref.arguments(), dim, identity, identity,
+          /*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
     OperationAccumulator accumulator{*array, operation};
     return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
   }
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 6ae069df5d7a42..57d9997dd767c5 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -52,7 +52,8 @@ template <int KIND> class Norm2Accumulator {
   Norm2Accumulator(
       const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding)
       : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
-  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+  void operator()(
+      Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
     // Kahan summation of scaled elements:
     // Naively,
     //   NORM2(A(:)) = SQRT(SUM(A(:)**2))
@@ -115,9 +116,9 @@ static Expr<Type<TypeCategory::Real, KIND>> FoldNorm2(FoldingContext &context,
   using Element = typename Constant<T>::Element;
   std::optional<int> dim;
   const Element identity{};
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, funcRef.arguments(), dim, identity,
-              /*X=*/0, /*DIM=*/1)}) {
+  if (std::optional<Constant<T>> array{ProcessReductionArgs<T>(context,
+          funcRef.arguments(), dim, identity, identity,
+          /*X=*/0, /*DIM=*/1)}) {
     MaxvalMinvalAccumulator<T, /*ABS=*/true> maxAbsAccumulator{
         RelationalOperator::GT, context, *array};
     Constant<T> maxAbs{
@@ -276,12 +277,13 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
     return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
   } else if (name == "maxval") {
     return FoldMaxvalMinval<T>(context, std::move(funcRef),
-        RelationalOperator::GT, T::Scalar::HUGE().Negate());
+        RelationalOperator::GT, T::Scalar::HUGE().Negate(),
+        T::Scalar::NotANumber());
   } else if (name == "min") {
     return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
   } else if (name == "minval") {
-    return FoldMaxvalMinval<T>(
-        context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
+    return FoldMaxvalMinval<T>(context, std::move(funcRef),
+        RelationalOperator::LT, T::Scalar::HUGE(), T::Scalar::NotANumber());
   } else if (name == "mod") {
     CHECK(args.size() == 2);
     return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 60c757dc3f4fa8..df1ad35be5a9eb 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -116,12 +116,13 @@ Constant<LogicalResult> *GetReductionMASK(
 // Common preprocessing for reduction transformational intrinsic function
 // folding.  If the intrinsic can have DIM= &/or MASK= arguments, extract
 // and check them.  If a MASK= is present, apply it to the array data and
-// substitute identity values for elements corresponding to .FALSE. in
+// substitute replacement values for elements corresponding to .FALSE. in
 // the mask.  If the result is present, the intrinsic call can be folded.
 template <typename T>
 static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
     ActualArguments &arg, std::optional<int> &dim, const Scalar<T> &identity,
-    int arrayIndex, std::optional<int> dimIndex = std::nullopt,
+    const Scalar<T> &partialMaskReplacement, int arrayIndex,
+    std::optional<int> dimIndex = std::nullopt,
     std::optional<int> maskIndex = std::nullopt) {
   if (arg.empty()) {
     return std::nullopt;
@@ -148,6 +149,15 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
         }
       } else { // mask is an array; test its elements
         elements = std::vector<typename Constant<T>::Element>(n, identity);
+        // Fill with partialMaskReplacement (NaN for real MAXVAL/MINVAL)
+        // rather than the identity, unless the mask is all false.
+        for (std::size_t j{0}; j < n; ++j) {
+          if (mask->values()[j].IsTrue()) {
+            elements = std::vector<typename Constant<T>::Element>(
+                n, partialMaskReplacement);
+            break;
+          }
+        }
         ConstantSubscripts at{folded->lbounds()};
         for (std::size_t j{0}; j < n; ++j, folded->IncrementSubscripts(at)) {
           if (mask->values()[j].IsTrue()) {
@@ -172,7 +182,8 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
 
 // Generalized reduction to an array of one dimension fewer (w/ DIM=)
 // or to a scalar (w/o DIM=).  The ACCUMULATOR type must define
-// operator()(Scalar<T> &, const ConstantSubscripts &) and Done(Scalar<T> &).
+// operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
+// and Done(Scalar<T> &).
 template <typename T, typename ACCUMULATOR, typename ARRAY>
 static Constant<T> DoReduction(const Constant<ARRAY> &array,
     std::optional<int> &dim, const Scalar<T> &identity,
@@ -191,15 +202,17 @@ static Constant<T> DoReduction(const Constant<ARRAY> &array,
       dimAt = dimLbound;
       elements.push_back(identity);
       for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt) {
-        accumulator(elements.back(), at);
+        accumulator(elements.back(), at, j == 0);
       }
       accumulator.Done(elements.back());
     }
   } else { // no DIM=, result is scalar
     elements.push_back(identity);
+    bool first{true};
     for (auto n{array.size()}; n-- > 0;
          IncrementSubscripts(at, array.shape())) {
-      accumulator(elements.back(), at);
+      accumulator(elements.back(), at, first);
+      first = false;
     }
     accumulator.Done(elements.back());
   }
@@ -217,11 +230,18 @@ template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
   MaxvalMinvalAccumulator(
       RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
       : opr_{opr}, context_{context}, array_{array} {};
-  void operator()(Scalar<T> &element, const ConstantSubscripts &at) const {
+  void operator()(Scalar<T> &element, const ConstantSubscripts &at,
+      [[maybe_unused]] bool first) const {
     auto aAt{array_.At(at)};
     if constexpr (ABS) {
       aAt = aAt.ABS();
     }
+    if constexpr (T::category == TypeCategory::Real) {
+      if (first || element.IsNotANumber()) {
+        element = aAt;
+        return;
+      }
+    }
     Expr<LogicalResult> test{PackageRelation(
         opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
     auto folded{GetScalarConstantValue<LogicalResult>(
@@ -241,14 +261,15 @@ template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
 
 template <typename T>
 static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
-    RelationalOperator opr, const Scalar<T> &identity) {
+    RelationalOperator opr, const Scalar<T> &identity,
+    const Scalar<T> &maskedReplacement) {
   static_assert(T::category == TypeCategory::Integer ||
       T::category == TypeCategory::Real ||
       T::category == TypeCategory::Character);
   std::optional<int> dim;
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
-              /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
+  if (std::optional<Constant<T>> array{ProcessReductionArgs<T>(context,
+          ref.arguments(), dim, identity, maskedReplacement,
+          /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
     MaxvalMinvalAccumulator accumulator{opr, context, *array};
     return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
   }
@@ -259,7 +280,8 @@ static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
 template <typename T> class ProductAccumulator {
 public:
   ProductAccumulator(const Constant<T> &array) : array_{array} {}
-  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+  void operator()(
+      Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
     if constexpr (T::category == TypeCategory::Integer) {
       auto prod{element.MultiplySigned(array_.At(at))};
       overflow_ |= prod.SignedMultiplicationOverflowed();
@@ -285,9 +307,9 @@ static Expr<T> FoldProduct(
       T::category == TypeCategory::Real ||
       T::category == TypeCategory::Complex);
   std::optional<int> dim;
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
-              /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
+  if (std::optional<Constant<T>> array{ProcessReductionArgs<T>(context,
+          ref.arguments(), dim, identity, identity,
+          /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
     ProductAccumulator accumulator{*array};
     auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
     if (accumulator.overflow()) {
@@ -306,7 +328,8 @@ template <typename T> class SumAccumulator {
 public:
   SumAccumulator(const Constant<T> &array, Rounding rounding)
       : array_{array}, rounding_{rounding} {}
-  void operator()(Element &element, const ConstantSubscripts &at) {
+  void operator()(
+      Element &element, const ConstantSubscripts &at, bool /*first*/) {
     if constexpr (T::category == TypeCategory::Integer) {
       auto sum{element.AddSigned(array_.At(at))};
       overflow_ |= sum.overflow;
@@ -348,9 +371,9 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
   using Element = typename Constant<T>::Element;
   std::optional<int> dim;
   Element identity{};
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
-              /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
+  if (std::optional<Constant<T>> array{ProcessReductionArgs<T>(context,
+          ref.arguments(), dim, identity, identity,
+          /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
     SumAccumulator accumulator{
         *array, context.targetCharacteristics().roundingMode()};
     auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
@@ -369,7 +392,8 @@ template <typename T> class OperationAccumulator {
   OperationAccumulator(const Constant<T> &array,
       Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
       : array_{array}, operation_{operation} {}
-  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+  void operator()(
+      Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
     element = (element.*operation_)(array_.At(at));
   }
   void Done(Scalar<T> &) const {}
diff --git a/flang/runtime/extrema.cpp b/flang/runtime/extrema.cpp
index edb5d5f47a5acf..5953a286594d83 100644
--- a/flang/runtime/extrema.cpp
+++ b/flang/runtime/extrema.cpp
@@ -19,6 +19,7 @@
 #include <cinttypes>
 #include <cmath>
 #include <optional>
+#include <type_traits>
 
 namespace Fortran::runtime {
 
@@ -28,7 +29,9 @@ template <typename T, bool IS_MAX, bool BACK> struct NumericCompare {
   using Type = T;
   explicit RT_API_ATTRS NumericCompare(std::size_t /*elemLen; ignored*/) {}
   RT_API_ATTRS bool operator()(const T &value, const T &previous) const {
-    if (value == previous) {
+    if (std::is_floating_point_v<T> && previous != previous) {
+      return BACK || value == value; // replace NaN
+    } else if (value == previous) {
       return BACK;
     } else if constexpr (IS_MAX) {
       return value > previous;
@@ -485,6 +488,7 @@ class NumericExtremumAccumulator {
   explicit RT_API_ATTRS NumericExtremumAccumulator(const Descriptor &array)
       : array_{array} {}
   RT_API_ATTRS void Reinitialize() {
+    any_ = false;
     extremum_ = MaxOrMinIdentity<CAT, KIND, IS_MAXVAL>::Value();
   }
   template <typename A>
@@ -492,7 +496,12 @@ class NumericExtremumAccumulator {
     *p = extremum_;
   }
   RT_API_ATTRS bool Accumulate(Type x) {
-    if constexpr (IS_MAXVAL) {
+    if (!any_) {
+      extremum_ = x;
+      any_ = true;
+    } else if (CAT == TypeCategory::Real && extremum_ != extremum_) {
+      extremum_ = x; // replace NaN
+    } else if constexpr (IS_MAXVAL) {
       if (x > extremum_) {
         extremum_ = x;
       }
@@ -508,6 +517,7 @@ class NumericExtremumAccumulator {
 
 private:
   const Descriptor &array_;
+  bool any_{false};
   Type extremum_{MaxOrMinIdentity<CAT, KIND, IS_MAXVAL>::Value()};
 };
 
diff --git a/flang/test/Evaluate/fold-findloc.f90 b/flang/test/Evaluate/fold-findloc.f90
index b8bb85af65dc6c..0da74fddfdfb26 100644
--- a/flang/test/Evaluate/fold-findloc.f90
+++ b/flang/test/Evaluate/fold-findloc.f90
@@ -4,6 +4,9 @@ module m1
   integer, parameter :: ia1(2:6) = [1, 2, 3, 2, 1]
   integer, parameter :: ia2(2:3,2:4) = reshape([1, 2, 3, 3, 2, 1], shape(ia2))
   integer, parameter :: ia3(2,0,2) = 0 ! middle dimension has zero extent
+  real, parameter :: nan = real(z'7FC00000')
+  real, parameter :: nans(*) = [nan, nan]
+  real, parameter :: someNans(*) = [nan, 0.]
 
   logical, parameter :: test_fi1a = all(findloc(ia1, 1) == 1)
   logical, parameter :: test_fi1ar = rank(findloc(ia1, 1)) == 1
@@ -85,4 +88,13 @@ module m1
   logical, parameter:: test_fia1_mfd = all(findloc(ia1, 1, mask=.false., dim=1) == [0])
   logical, parameter:: test_fia2_mfd1 = all(findloc(ia2, 1, dim=1, mask=.false.) == [0, 0, 0])
   logical, parameter:: test_fia2_mfd2 = all(findloc(ia2, 1, dim=2, mask=.false.) == [0, 0])
+
+  logical, parameter :: test_nan1 = maxloc(nans,1) == 1
+  logical, parameter :: test_nan2 = maxloc(nans,1,back=.true.) == 2
+  logical, parameter :: test_nan3 = minloc(nans,1) == 1
+  logical, parameter :: test_nan4 = minloc(nans,1,back=.true.) == 2
+  logical, parameter :: test_nan5 = maxloc(someNans,1) == 2
+  logical, parameter :: test_nan6 = maxloc(someNans,1,back=.true.) == 2
+  logical, parameter :: test_nan7 = minloc(someNans,1) == 2
+  logical, parameter :: test_nan8 = minloc(someNans,1,back=.true.) == 2
 end module
diff --git a/flang/test/Evaluate/folding20.f90 b/flang/test/Evaluate/folding20.f90
index be012074fb4bda..cc8847230c3812 100644
--- a/flang/test/Evaluate/folding20.f90
+++ b/flang/test/Evaluate/folding20.f90
@@ -5,6 +5,10 @@ module m
   integer, parameter :: intmatrix(*,*) = reshape([1, 2, 3, 4, 5, 6], [2, 3])
   logical, parameter :: odds(2,3) = mod(intmatrix, 2) == 1
   character(*), parameter :: chmatrix(*,*) = reshape(['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr'], [2, 3])
+  real, parameter :: nan = real(z'7FC00000'), inf = real(z'7F800000')
+  real, parameter :: nans(*) = [nan, nan]
+  real, parameter :: someNan(*) = [nan, 0.]
+  real, parameter :: someInf(*) = [inf, 0.]
 
   logical, parameter :: test_allidentity = all([Logical::])
   logical, parameter :: test_all = .not. all(odds)
@@ -43,6 +47,12 @@ module m
   logical, parameter :: test_rminval = minval(real(intmatrix)) == 1.0
   logical, parameter :: test_rmaxval_scalar_mask = maxval(real(intmatrix), .true.) == 6.0
   logical, parameter :: test_rminval_scalar_mask = minval(real(intmatrix), .false.) == huge(0.0)
+  logical, parameter :: test_rmaxval_allNaN = maxval(nans) /= maxval(nans)
+  logical, parameter :: test_rminval_allNaN = minval(nans) /= minval(nans) ! all-NaN MINVAL folds to NaN
+  logical, parameter :: test_rmaxval_someNaN = maxval(someNan) == 0.
+  logical, parameter :: test_rminval_someNaN = minval(someNan) == 0.
+  logical, parameter :: test_rmaxval_someInf = maxval(someInf) == inf
+  logical, parameter :: test_rminval_someInf = minval(-someInf) == -inf
   logical, parameter :: test_cmaxlen = len(maxval([character*4::])) == 4
   logical, parameter :: test_cmaxidentity = maxval([character*4::]) == repeat(char(0), 4)
   logical, parameter :: test_cminidentity = minval([character*4::]) == repeat(char(127), 4)



More information about the flang-commits mailing list