[flang-commits] [flang] [flang][runtime] Treatment of NaN in MAXVAL/MAXLOC/MINVAL/MINLOC (PR #76999)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Tue Jan 9 11:35:58 PST 2024


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/76999

>From 6a0dec0f2fef121687a3bdf5a71de325c7e80b96 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Thu, 4 Jan 2024 10:26:00 -0800
Subject: [PATCH] [flang][runtime] Treatment of NaN in
 MAXVAL/MAXLOC/MINVAL/MINLOC

Detect NaN elements in data and handle them like gfortran does
(at runtime); namely, NaN is returned only if there are unmasked
elements and all of them are NaNs, and any non-NaN value is preferred
over a NaN.  Ensure that compile-time folding returns the same results
as runtime computation.

Fixes llvm-test-suite/Fortran/gfortran/regression/maxloc_2.f90
(and probably others).
---
 flang/docs/Extensions.md              |   4 +
 flang/lib/Evaluate/fold-character.cpp |   2 +-
 flang/lib/Evaluate/fold-integer.cpp   |  60 +++++++-----
 flang/lib/Evaluate/fold-logical.cpp   |   9 +-
 flang/lib/Evaluate/fold-real.cpp      |  22 +++--
 flang/lib/Evaluate/fold-reduction.h   | 126 +++++++++++++++-----------
 flang/runtime/extrema.cpp             |  26 ++++--
 flang/runtime/reduction-templates.h   |   3 +-
 flang/test/Evaluate/fold-findloc.f90  |  12 +++
 flang/test/Evaluate/folding20.f90     |  12 ++-
 10 files changed, 173 insertions(+), 103 deletions(-)

diff --git a/flang/docs/Extensions.md b/flang/docs/Extensions.md
index 16eb67f2e27c81..94f4b31ec3822c 100644
--- a/flang/docs/Extensions.md
+++ b/flang/docs/Extensions.md
@@ -654,6 +654,10 @@ end
   we don't round.  This seems to be how the Intel Fortran compilers
   behave.
 
+* For real `MAXVAL`, `MINVAL`, `MAXLOC`, and `MINLOC`, NaN values are
+  essentially ignored unless there are some unmasked array entries and
+  *all* of them are NaNs.
+
 ## De Facto Standard Features
 
 * `EXTENDS_TYPE_OF()` returns `.TRUE.` if both of its arguments have the
diff --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp
index a599815fa7aee9..5d9cc11754a7dd 100644
--- a/flang/lib/Evaluate/fold-character.cpp
+++ b/flang/lib/Evaluate/fold-character.cpp
@@ -84,7 +84,7 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
     return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
   } else if (name == "minval") {
     // Collating sequences correspond to positive integers (3.31)
-    SingleCharType most{0x7fffffff >> (8 * (4 - KIND))};
+    auto most{static_cast<SingleCharType>(0xffffffff >> (8 * (4 - KIND)))};
     if (auto identity{Identity<T>(
             StringType{most}, GetConstantLength(context, funcRef, 0))}) {
       return FoldMaxvalMinval<T>(
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index ba4bc6a04750ff..0e8706e0f27405 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -270,7 +270,8 @@ template <typename T, int MASK_KIND> class CountAccumulator {
 
 public:
   CountAccumulator(const Constant<MaskT> &mask) : mask_{mask} {}
-  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+  void operator()(
+      Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
     if (mask_.At(at).IsTrue()) {
       auto incremented{element.AddSigned(Scalar<T>{1})};
       overflow_ |= incremented.overflow;
@@ -287,22 +288,20 @@ template <typename T, int MASK_KIND> class CountAccumulator {
 
 template <typename T, int maskKind>
 static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
-  using LogicalResult = Type<TypeCategory::Logical, maskKind>;
+  using KindLogical = Type<TypeCategory::Logical, maskKind>;
   static_assert(T::category == TypeCategory::Integer);
-  ActualArguments &arg{ref.arguments()};
-  if (const Constant<LogicalResult> *mask{arg.empty()
-              ? nullptr
-              : Folder<LogicalResult>{context}.Folding(arg[0])}) {
-    std::optional<int> dim;
-    if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
-      CountAccumulator<T, maskKind> accumulator{*mask};
-      Constant<T> result{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
-      if (accumulator.overflow()) {
-        context.messages().Say(
-            "Result of intrinsic function COUNT overflows its result type"_warn_en_US);
-      }
-      return Expr<T>{std::move(result)};
+  std::optional<int> dim;
+  if (std::optional<ArrayAndMask<KindLogical>> arrayAndMask{
+          ProcessReductionArgs<KindLogical>(
+              context, ref.arguments(), dim, /*ARRAY=*/0, /*DIM=*/1)}) {
+    CountAccumulator<T, maskKind> accumulator{arrayAndMask->array};
+    Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask,
+        dim, Scalar<T>{}, accumulator)};
+    if (accumulator.overflow()) {
+      context.messages().Say(
+          "Result of intrinsic function COUNT overflows its result type"_warn_en_US);
     }
+    return Expr<T>{std::move(result)};
   }
   return Expr<T>{std::move(ref)};
 }
@@ -395,7 +394,7 @@ template <WhichLocation WHICH> class LocationHelper {
         for (ConstantSubscript k{0}; k < dimLength;
              ++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
           if ((!mask || mask->At(maskAt).IsTrue()) &&
-              IsHit(array->At(at), value, relation)) {
+              IsHit(array->At(at), value, relation, back)) {
             hit = at[zbDim];
             if constexpr (WHICH == WhichLocation::Findloc) {
               if (!back) {
@@ -422,7 +421,7 @@ template <WhichLocation WHICH> class LocationHelper {
       for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
            mask && mask->IncrementSubscripts(maskAt)) {
         if ((!mask || mask->At(maskAt).IsTrue()) &&
-            IsHit(array->At(at), value, relation)) {
+            IsHit(array->At(at), value, relation, back)) {
           resultIndices = at;
           if constexpr (WHICH == WhichLocation::Findloc) {
             if (!back) {
@@ -444,7 +443,8 @@ template <WhichLocation WHICH> class LocationHelper {
   template <typename T>
   bool IsHit(typename Constant<T>::Element element,
       std::optional<Constant<T>> &value,
-      [[maybe_unused]] RelationalOperator relation) const {
+      [[maybe_unused]] RelationalOperator relation,
+      [[maybe_unused]] bool back) const {
     std::optional<Expr<LogicalResult>> cmp;
     bool result{true};
     if (value) {
@@ -455,8 +455,19 @@ template <WhichLocation WHICH> class LocationHelper {
             Expr<T>{LogicalOperation<T::kind>{LogicalOperator::Eqv,
                 Expr<T>{Constant<T>{element}}, Expr<T>{Constant<T>{*value}}}}));
       } else { // compare array(at) to value
-        cmp.emplace(PackageRelation(relation, Expr<T>{Constant<T>{element}},
-            Expr<T>{Constant<T>{*value}}));
+        if constexpr (T::category == TypeCategory::Real &&
+            (WHICH == WhichLocation::Maxloc ||
+                WHICH == WhichLocation::Minloc)) {
+          if (value && value->GetScalarValue().value().IsNotANumber() &&
+              (back || !element.IsNotANumber())) {
+            // Replace NaN
+            cmp.emplace(Constant<LogicalResult>{Scalar<LogicalResult>{true}});
+          }
+        }
+        if (!cmp) {
+          cmp.emplace(PackageRelation(relation, Expr<T>{Constant<T>{element}},
+              Expr<T>{Constant<T>{*value}}));
+        }
       }
       Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
       result = GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
@@ -523,11 +534,12 @@ static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
     Scalar<T> identity) {
   static_assert(T::category == TypeCategory::Integer);
   std::optional<int> dim;
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
+  if (std::optional<ArrayAndMask<T>> arrayAndMask{
+          ProcessReductionArgs<T>(context, ref.arguments(), dim,
               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
-    OperationAccumulator<T> accumulator{*array, operation};
-    return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
+    OperationAccumulator<T> accumulator{arrayAndMask->array, operation};
+    return Expr<T>{DoReduction<T>(
+        arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
   }
   return Expr<T>{std::move(ref)};
 }
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index 82a5cb20db9e40..5a9596f3c274b5 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -31,11 +31,12 @@ static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
     Scalar<T> identity) {
   static_assert(T::category == TypeCategory::Logical);
   std::optional<int> dim;
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
+  if (std::optional<ArrayAndMask<T>> arrayAndMask{
+          ProcessReductionArgs<T>(context, ref.arguments(), dim,
               /*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
-    OperationAccumulator accumulator{*array, operation};
-    return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
+    OperationAccumulator accumulator{arrayAndMask->array, operation};
+    return Expr<T>{DoReduction<T>(
+        arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
   }
   return Expr<T>{std::move(ref)};
 }
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 6ae069df5d7a42..fd37437c643aa2 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -52,7 +52,8 @@ template <int KIND> class Norm2Accumulator {
   Norm2Accumulator(
       const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding)
       : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
-  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+  void operator()(
+      Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
     // Kahan summation of scaled elements:
     // Naively,
     //   NORM2(A(:)) = SQRT(SUM(A(:)**2))
@@ -114,17 +115,18 @@ static Expr<Type<TypeCategory::Real, KIND>> FoldNorm2(FoldingContext &context,
   using T = Type<TypeCategory::Real, KIND>;
   using Element = typename Constant<T>::Element;
   std::optional<int> dim;
-  const Element identity{};
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, funcRef.arguments(), dim, identity,
+  if (std::optional<ArrayAndMask<T>> arrayAndMask{
+          ProcessReductionArgs<T>(context, funcRef.arguments(), dim,
               /*X=*/0, /*DIM=*/1)}) {
     MaxvalMinvalAccumulator<T, /*ABS=*/true> maxAbsAccumulator{
-        RelationalOperator::GT, context, *array};
-    Constant<T> maxAbs{
-        DoReduction<T>(*array, dim, identity, maxAbsAccumulator)};
-    Norm2Accumulator norm2Accumulator{
-        *array, maxAbs, context.targetCharacteristics().roundingMode()};
-    Constant<T> result{DoReduction<T>(*array, dim, identity, norm2Accumulator)};
+        RelationalOperator::GT, context, arrayAndMask->array};
+    const Element identity{};
+    Constant<T> maxAbs{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask,
+        dim, identity, maxAbsAccumulator)};
+    Norm2Accumulator norm2Accumulator{arrayAndMask->array, maxAbs,
+        context.targetCharacteristics().roundingMode()};
+    Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask,
+        dim, identity, norm2Accumulator)};
     if (norm2Accumulator.overflow()) {
       context.messages().Say(
           "NORM2() of REAL(%d) data overflowed"_warn_en_US, KIND);
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 60c757dc3f4fa8..1ee957c0faebd8 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -116,11 +116,15 @@ Constant<LogicalResult> *GetReductionMASK(
 // Common preprocessing for reduction transformational intrinsic function
 // folding.  If the intrinsic can have DIM= &/or MASK= arguments, extract
 // and check them.  If a MASK= is present, apply it to the array data and
-// substitute identity values for elements corresponding to .FALSE. in
+// substitute replacement values for elements corresponding to .FALSE. in
 // the mask.  If the result is present, the intrinsic call can be folded.
+template <typename T> struct ArrayAndMask {
+  Constant<T> array;
+  Constant<LogicalResult> mask;
+};
 template <typename T>
-static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
-    ActualArguments &arg, std::optional<int> &dim, const Scalar<T> &identity,
+static std::optional<ArrayAndMask<T>> ProcessReductionArgs(
+    FoldingContext &context, ActualArguments &arg, std::optional<int> &dim,
     int arrayIndex, std::optional<int> dimIndex = std::nullopt,
     std::optional<int> maskIndex = std::nullopt) {
   if (arg.empty()) {
@@ -133,73 +137,74 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
   if (!CheckReductionDIM(dim, context, arg, dimIndex, folded->Rank())) {
     return std::nullopt;
   }
+  std::size_t n{folded->size()};
+  std::vector<Scalar<LogicalResult>> maskElement;
   if (maskIndex && static_cast<std::size_t>(*maskIndex) < arg.size() &&
       arg[*maskIndex]) {
-    if (const Constant<LogicalResult> *mask{
+    if (const Constant<LogicalResult> *origMask{
             GetReductionMASK(arg[*maskIndex], folded->shape(), context)}) {
-      // Apply the mask in place to the array
-      std::size_t n{folded->size()};
-      std::vector<typename Constant<T>::Element> elements;
-      if (auto scalarMask{mask->GetScalarValue()}) {
-        if (scalarMask->IsTrue()) {
-          return Constant<T>{*folded};
-        } else { // MASK=.FALSE.
-          elements = std::vector<typename Constant<T>::Element>(n, identity);
-        }
-      } else { // mask is an array; test its elements
-        elements = std::vector<typename Constant<T>::Element>(n, identity);
-        ConstantSubscripts at{folded->lbounds()};
-        for (std::size_t j{0}; j < n; ++j, folded->IncrementSubscripts(at)) {
-          if (mask->values()[j].IsTrue()) {
-            elements[j] = folded->At(at);
-          }
-        }
-      }
-      if constexpr (T::category == TypeCategory::Character) {
-        return Constant<T>{static_cast<ConstantSubscript>(identity.size()),
-            std::move(elements), ConstantSubscripts{folded->shape()}};
+      if (auto scalarMask{origMask->GetScalarValue()}) {
+        maskElement =
+            std::vector<Scalar<LogicalResult>>(n, scalarMask->IsTrue());
       } else {
-        return Constant<T>{
-            std::move(elements), ConstantSubscripts{folded->shape()}};
+        maskElement = origMask->values();
       }
     } else {
       return std::nullopt;
     }
   } else {
-    return Constant<T>{*folded};
+    maskElement = std::vector<Scalar<LogicalResult>>(n, true);
   }
+  return ArrayAndMask<T>{Constant<T>(*folded),
+      Constant<LogicalResult>{
+          std::move(maskElement), ConstantSubscripts{folded->shape()}}};
 }
 
 // Generalized reduction to an array of one dimension fewer (w/ DIM=)
 // or to a scalar (w/o DIM=).  The ACCUMULATOR type must define
-// operator()(Scalar<T> &, const ConstantSubscripts &) and Done(Scalar<T> &).
+// operator()(Scalar<T> &, const ConstantSubscripts &, bool first)
+// and Done(Scalar<T> &).
 template <typename T, typename ACCUMULATOR, typename ARRAY>
 static Constant<T> DoReduction(const Constant<ARRAY> &array,
-    std::optional<int> &dim, const Scalar<T> &identity,
-    ACCUMULATOR &accumulator) {
+    const Constant<LogicalResult> &mask, std::optional<int> &dim,
+    const Scalar<T> &identity, ACCUMULATOR &accumulator) {
   ConstantSubscripts at{array.lbounds()};
+  ConstantSubscripts maskAt{mask.lbounds()};
   std::vector<typename Constant<T>::Element> elements;
   ConstantSubscripts resultShape; // empty -> scalar
   if (dim) { // DIM= is present, so result is an array
     resultShape = array.shape();
     resultShape.erase(resultShape.begin() + (*dim - 1));
     ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
+    CHECK(dimExtent == mask.shape().at(*dim - 1));
     ConstantSubscript &dimAt{at[*dim - 1]};
     ConstantSubscript dimLbound{dimAt};
+    ConstantSubscript &maskDimAt{maskAt[*dim - 1]};
+    ConstantSubscript maskDimLbound{maskDimAt};
     for (auto n{GetSize(resultShape)}; n-- > 0;
-         IncrementSubscripts(at, array.shape())) {
+         IncrementSubscripts(at, array.shape()),
+         IncrementSubscripts(maskAt, mask.shape())) {
       dimAt = dimLbound;
+      maskDimAt = maskDimLbound;
       elements.push_back(identity);
-      for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt) {
-        accumulator(elements.back(), at);
+      bool firstUnmasked{true};
+      for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt, ++maskDimAt) {
+        if (mask.At(maskAt).IsTrue()) {
+          accumulator(elements.back(), at, firstUnmasked);
+          firstUnmasked = false;
+        }
       }
       accumulator.Done(elements.back());
     }
   } else { // no DIM=, result is scalar
     elements.push_back(identity);
-    for (auto n{array.size()}; n-- > 0;
-         IncrementSubscripts(at, array.shape())) {
-      accumulator(elements.back(), at);
+    bool firstUnmasked{true};
+    for (auto n{array.size()}; n-- > 0; IncrementSubscripts(at, array.shape()),
+         IncrementSubscripts(maskAt, mask.shape())) {
+      if (mask.At(maskAt).IsTrue()) {
+        accumulator(elements.back(), at, firstUnmasked);
+        firstUnmasked = false;
+      }
     }
     accumulator.Done(elements.back());
   }
@@ -217,11 +222,20 @@ template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
   MaxvalMinvalAccumulator(
       RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
       : opr_{opr}, context_{context}, array_{array} {};
-  void operator()(Scalar<T> &element, const ConstantSubscripts &at) const {
+  void operator()(Scalar<T> &element, const ConstantSubscripts &at,
+      [[maybe_unused]] bool firstUnmasked) const {
     auto aAt{array_.At(at)};
     if constexpr (ABS) {
       aAt = aAt.ABS();
     }
+    if constexpr (T::category == TypeCategory::Real) {
+      if (firstUnmasked || element.IsNotANumber()) {
+        // Return NaN if and only if all unmasked elements are NaNs and
+        // at least one unmasked element is visible.
+        element = aAt;
+        return;
+      }
+    }
     Expr<LogicalResult> test{PackageRelation(
         opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
     auto folded{GetScalarConstantValue<LogicalResult>(
@@ -246,11 +260,12 @@ static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
       T::category == TypeCategory::Real ||
       T::category == TypeCategory::Character);
   std::optional<int> dim;
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
+  if (std::optional<ArrayAndMask<T>> arrayAndMask{
+          ProcessReductionArgs<T>(context, ref.arguments(), dim,
               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
-    MaxvalMinvalAccumulator accumulator{opr, context, *array};
-    return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
+    MaxvalMinvalAccumulator accumulator{opr, context, arrayAndMask->array};
+    return Expr<T>{DoReduction<T>(
+        arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
   }
   return Expr<T>{std::move(ref)};
 }
@@ -259,7 +274,8 @@ static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
 template <typename T> class ProductAccumulator {
 public:
   ProductAccumulator(const Constant<T> &array) : array_{array} {}
-  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+  void operator()(
+      Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
     if constexpr (T::category == TypeCategory::Integer) {
       auto prod{element.MultiplySigned(array_.At(at))};
       overflow_ |= prod.SignedMultiplicationOverflowed();
@@ -285,11 +301,12 @@ static Expr<T> FoldProduct(
       T::category == TypeCategory::Real ||
       T::category == TypeCategory::Complex);
   std::optional<int> dim;
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
+  if (std::optional<ArrayAndMask<T>> arrayAndMask{
+          ProcessReductionArgs<T>(context, ref.arguments(), dim,
               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
-    ProductAccumulator accumulator{*array};
-    auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
+    ProductAccumulator accumulator{arrayAndMask->array};
+    auto result{Expr<T>{DoReduction<T>(
+        arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
     if (accumulator.overflow()) {
       context.messages().Say(
           "PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
@@ -306,7 +323,8 @@ template <typename T> class SumAccumulator {
 public:
   SumAccumulator(const Constant<T> &array, Rounding rounding)
       : array_{array}, rounding_{rounding} {}
-  void operator()(Element &element, const ConstantSubscripts &at) {
+  void operator()(
+      Element &element, const ConstantSubscripts &at, bool /*first*/) {
     if constexpr (T::category == TypeCategory::Integer) {
       auto sum{element.AddSigned(array_.At(at))};
       overflow_ |= sum.overflow;
@@ -348,12 +366,13 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
   using Element = typename Constant<T>::Element;
   std::optional<int> dim;
   Element identity{};
-  if (std::optional<Constant<T>> array{
-          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
+  if (std::optional<ArrayAndMask<T>> arrayAndMask{
+          ProcessReductionArgs<T>(context, ref.arguments(), dim,
               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
     SumAccumulator accumulator{
-        *array, context.targetCharacteristics().roundingMode()};
-    auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
+        arrayAndMask->array, context.targetCharacteristics().roundingMode()};
+    auto result{Expr<T>{DoReduction<T>(
+        arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)}};
     if (accumulator.overflow()) {
       context.messages().Say(
           "SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
@@ -369,7 +388,8 @@ template <typename T> class OperationAccumulator {
   OperationAccumulator(const Constant<T> &array,
       Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const)
       : array_{array}, operation_{operation} {}
-  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+  void operator()(
+      Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
     element = (element.*operation_)(array_.At(at));
   }
   void Done(Scalar<T> &) const {}
diff --git a/flang/runtime/extrema.cpp b/flang/runtime/extrema.cpp
index edb5d5f47a5acf..8840b855ee3ff8 100644
--- a/flang/runtime/extrema.cpp
+++ b/flang/runtime/extrema.cpp
@@ -19,6 +19,7 @@
 #include <cinttypes>
 #include <cmath>
 #include <optional>
+#include <type_traits>
 
 namespace Fortran::runtime {
 
@@ -28,7 +29,9 @@ template <typename T, bool IS_MAX, bool BACK> struct NumericCompare {
   using Type = T;
   explicit RT_API_ATTRS NumericCompare(std::size_t /*elemLen; ignored*/) {}
   RT_API_ATTRS bool operator()(const T &value, const T &previous) const {
-    if (value == previous) {
+    if (std::is_floating_point_v<T> && previous != previous) {
+      return BACK || value == value; // replace NaN
+    } else if (value == previous) {
       return BACK;
     } else if constexpr (IS_MAX) {
       return value > previous;
@@ -76,11 +79,10 @@ template <typename COMPARE> class ExtremumLocAccumulator {
   template <typename A>
   RT_API_ATTRS void GetResult(A *p, int zeroBasedDim = -1) {
     if (zeroBasedDim >= 0) {
-      *p = extremumLoc_[zeroBasedDim] -
-          array_.GetDimension(zeroBasedDim).LowerBound() + 1;
+      *p = extremumLoc_[zeroBasedDim];
     } else {
       for (int j{0}; j < argRank_; ++j) {
-        p[j] = extremumLoc_[j] - array_.GetDimension(j).LowerBound() + 1;
+        p[j] = extremumLoc_[j];
       }
     }
   }
@@ -90,7 +92,7 @@ template <typename COMPARE> class ExtremumLocAccumulator {
     if (!previous_ || compare_(value, *previous_)) {
       previous_ = &value;
       for (int j{0}; j < argRank_; ++j) {
-        extremumLoc_[j] = at[j];
+        extremumLoc_[j] = at[j] - array_.GetDimension(j).LowerBound() + 1;
       }
     }
     return true;
@@ -485,6 +487,7 @@ class NumericExtremumAccumulator {
   explicit RT_API_ATTRS NumericExtremumAccumulator(const Descriptor &array)
       : array_{array} {}
   RT_API_ATTRS void Reinitialize() {
+    any_ = false;
     extremum_ = MaxOrMinIdentity<CAT, KIND, IS_MAXVAL>::Value();
   }
   template <typename A>
@@ -492,7 +495,12 @@ class NumericExtremumAccumulator {
     *p = extremum_;
   }
   RT_API_ATTRS bool Accumulate(Type x) {
-    if constexpr (IS_MAXVAL) {
+    if (!any_) {
+      extremum_ = x;
+      any_ = true;
+    } else if (CAT == TypeCategory::Real && extremum_ != extremum_) {
+      extremum_ = x; // replace NaN
+    } else if constexpr (IS_MAXVAL) {
       if (x > extremum_) {
         extremum_ = x;
       }
@@ -508,6 +516,7 @@ class NumericExtremumAccumulator {
 
 private:
   const Descriptor &array_;
+  bool any_{false};
   Type extremum_{MaxOrMinIdentity<CAT, KIND, IS_MAXVAL>::Value()};
 };
 
@@ -598,9 +607,8 @@ template <int KIND, bool IS_MAXVAL> class CharacterExtremumAccumulator {
       std::memcpy(p, extremum_, byteSize);
     } else {
       // Empty array; fill with character 0 for MAXVAL.
-      // For MINVAL, fill with 127 if ASCII as required
-      // by the standard, otherwise set all of the bits.
-      std::memset(p, IS_MAXVAL ? 0 : KIND == 1 ? 127 : 255, byteSize);
+      // For MINVAL, set all of the bits.
+      std::memset(p, IS_MAXVAL ? 0 : 255, byteSize);
     }
   }
   RT_API_ATTRS bool Accumulate(const Type *x) {
diff --git a/flang/runtime/reduction-templates.h b/flang/runtime/reduction-templates.h
index cf1ee8a967750d..7d0f82d59a084d 100644
--- a/flang/runtime/reduction-templates.h
+++ b/flang/runtime/reduction-templates.h
@@ -57,8 +57,9 @@ inline RT_API_ATTRS void DoTotalReduction(const Descriptor &x, int dim,
       for (auto elements{x.Elements()}; elements--;
            x.IncrementSubscripts(xAt), mask->IncrementSubscripts(maskAt)) {
         if (IsLogicalElementTrue(*mask, maskAt)) {
-          if (!accumulator.template AccumulateAt<TYPE>(xAt))
+          if (!accumulator.template AccumulateAt<TYPE>(xAt)) {
             break;
+          }
         }
       }
       return;
diff --git a/flang/test/Evaluate/fold-findloc.f90 b/flang/test/Evaluate/fold-findloc.f90
index b8bb85af65dc6c..0da74fddfdfb26 100644
--- a/flang/test/Evaluate/fold-findloc.f90
+++ b/flang/test/Evaluate/fold-findloc.f90
@@ -4,6 +4,9 @@ module m1
   integer, parameter :: ia1(2:6) = [1, 2, 3, 2, 1]
   integer, parameter :: ia2(2:3,2:4) = reshape([1, 2, 3, 3, 2, 1], shape(ia2))
   integer, parameter :: ia3(2,0,2) = 0 ! middle dimension has zero extent
+  real, parameter :: nan = real(z'7FC00000')
+  real, parameter :: nans(*) = [nan, nan]
+  real, parameter :: someNans(*) = [nan, 0.]
 
   logical, parameter :: test_fi1a = all(findloc(ia1, 1) == 1)
   logical, parameter :: test_fi1ar = rank(findloc(ia1, 1)) == 1
@@ -85,4 +88,13 @@ module m1
   logical, parameter:: test_fia1_mfd = all(findloc(ia1, 1, mask=.false., dim=1) == [0])
   logical, parameter:: test_fia2_mfd1 = all(findloc(ia2, 1, dim=1, mask=.false.) == [0, 0, 0])
   logical, parameter:: test_fia2_mfd2 = all(findloc(ia2, 1, dim=2, mask=.false.) == [0, 0])
+
+  logical, parameter :: test_nan1 = maxloc(nans,1) == 1
+  logical, parameter :: test_nan2 = maxloc(nans,1,back=.true.) == 2
+  logical, parameter :: test_nan3 = minloc(nans,1) == 1
+  logical, parameter :: test_nan4 = minloc(nans,1,back=.true.) == 2
+  logical, parameter :: test_nan5 = maxloc(someNans,1) == 2
+  logical, parameter :: test_nan6 = maxloc(someNans,1,back=.true.) == 2
+  logical, parameter :: test_nan7 = minloc(someNans,1) == 2
+  logical, parameter :: test_nan8 = minloc(someNans,1,back=.true.) == 2
 end module
diff --git a/flang/test/Evaluate/folding20.f90 b/flang/test/Evaluate/folding20.f90
index be012074fb4bda..da3c1c8137918b 100644
--- a/flang/test/Evaluate/folding20.f90
+++ b/flang/test/Evaluate/folding20.f90
@@ -5,6 +5,10 @@ module m
   integer, parameter :: intmatrix(*,*) = reshape([1, 2, 3, 4, 5, 6], [2, 3])
   logical, parameter :: odds(2,3) = mod(intmatrix, 2) == 1
   character(*), parameter :: chmatrix(*,*) = reshape(['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr'], [2, 3])
+  real, parameter :: nan = real(z'7FC00000'), inf = real(z'7F800000')
+  real, parameter :: nans(*) = [nan, nan]
+  real, parameter :: someNan(*) = [nan, 0.]
+  real, parameter :: someInf(*) = [inf, 0.]
 
   logical, parameter :: test_allidentity = all([Logical::])
   logical, parameter :: test_all = .not. all(odds)
@@ -43,9 +47,15 @@ module m
   logical, parameter :: test_rminval = minval(real(intmatrix)) == 1.0
   logical, parameter :: test_rmaxval_scalar_mask = maxval(real(intmatrix), .true.) == 6.0
   logical, parameter :: test_rminval_scalar_mask = minval(real(intmatrix), .false.) == huge(0.0)
+  logical, parameter :: test_rmaxval_allNaN = maxval(nans) /= maxval(nans)
+  logical, parameter :: test_rminval_allNaN = minval(nans) /= minval(nans) ! x /= x is true only for NaN
+  logical, parameter :: test_rmaxval_someNaN = maxval(someNan) == 0.
+  logical, parameter :: test_rminval_someNaN = minval(someNan) == 0.
+  logical, parameter :: test_rmaxval_someInf = maxval(someInf) == inf
+  logical, parameter :: test_rminval_someInf = minval(-someInf) == -inf
   logical, parameter :: test_cmaxlen = len(maxval([character*4::])) == 4
   logical, parameter :: test_cmaxidentity = maxval([character*4::]) == repeat(char(0), 4)
-  logical, parameter :: test_cminidentity = minval([character*4::]) == repeat(char(127), 4)
+  logical, parameter :: test_cminidentity = minval([character*4::]) == repeat(char(255), 4)
   logical, parameter :: test_cmaxval = maxval(chmatrix) == 'pqr'
   logical, parameter :: test_cminval = minval(chmatrix) == 'abc'
   logical, parameter :: test_maxvaldim1 = all(maxval(intmatrix,dim=1) == [2, 4, 6])



More information about the flang-commits mailing list