[flang-commits] [flang] 503c085 - [flang] Fold more reduction intrinsic function calls

peter klausler via flang-commits flang-commits at lists.llvm.org
Mon Jun 21 10:14:09 PDT 2021


Author: peter klausler
Date: 2021-06-21T10:13:59-07:00
New Revision: 503c085e3bcd0a031a363ee89c91b1f1e41bfa4b

URL: https://github.com/llvm/llvm-project/commit/503c085e3bcd0a031a363ee89c91b1f1e41bfa4b
DIFF: https://github.com/llvm/llvm-project/commit/503c085e3bcd0a031a363ee89c91b1f1e41bfa4b.diff

LOG: [flang] Fold more reduction intrinsic function calls

Refactor the recently-implemented MAXVAL/MINVAL folding so
that the parts that can be used to implement other reduction
transformational intrinsic function folding are exposed.

Use them to implement folding of IALL, IANY, IPARITY,
SUM, and PRODUCT.  Replace the folding of ALL & ANY to
use the new infrastructure and become able to handle DIM=
arguments.

Differential Revision: https://reviews.llvm.org/D104562

Added: 
    

Modified: 
    flang/include/flang/Evaluate/real.h
    flang/lib/Evaluate/fold-character.cpp
    flang/lib/Evaluate/fold-complex.cpp
    flang/lib/Evaluate/fold-integer.cpp
    flang/lib/Evaluate/fold-logical.cpp
    flang/lib/Evaluate/fold-real.cpp
    flang/lib/Evaluate/fold-reduction.h
    flang/test/Evaluate/folding20.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Evaluate/real.h b/flang/include/flang/Evaluate/real.h
index 85928e8d1f8ea..9cd6f8305bd5c 100644
--- a/flang/include/flang/Evaluate/real.h
+++ b/flang/include/flang/Evaluate/real.h
@@ -55,6 +55,7 @@ class Real : public common::RealDetails<PREC> {

   constexpr Real() {} // +0.0
   constexpr Real(const Real &) = default;
+  constexpr Real(Real &&) = default; // pairs with operator=(Real &&) below
   constexpr Real(const Word &bits) : word_{bits} {}
   constexpr Real &operator=(const Real &) = default;
   constexpr Real &operator=(Real &&) = default;

diff  --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp
index 385159ed2d5b7..d1a459652209a 100644
--- a/flang/lib/Evaluate/fold-character.cpp
+++ b/flang/lib/Evaluate/fold-character.cpp
@@ -102,8 +102,8 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
           CharacterUtils<KIND>::TRIM(std::get<Scalar<T>>(*scalar))}};
     }
   }
-  // TODO: cshift, eoshift, maxloc, minloc, pack, reduce,
-  // spread, transfer, transpose, unpack
+  // TODO: cshift, eoshift, maxloc, minloc, pack, spread, transfer,
+  // transpose, unpack
   return Expr<T>{std::move(funcRef)};
 }


diff  --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp
index 2e574aacc0d54..98949a47919f8 100644
--- a/flang/lib/Evaluate/fold-complex.cpp
+++ b/flang/lib/Evaluate/fold-complex.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//

 #include "fold-implementation.h"
+#include "fold-reduction.h"

 namespace Fortran::evaluate {

@@ -15,6 +16,7 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
     FoldingContext &context,
     FunctionRef<Type<TypeCategory::Complex, KIND>> &&funcRef) {
   using T = Type<TypeCategory::Complex, KIND>;
+  using Part = typename T::Part;
   ActualArguments &args{funcRef.arguments()};
   auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
   CHECK(intrinsic);
@@ -40,7 +42,6 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
         return Fold(context, ConvertToType<T>(std::move(*x)));
       } else {
         // CMPLX(X [, Y [, KIND]]) with non-complex X
-        using Part = typename T::Part;
         Expr<SomeType> re{std::move(*args[0].value().UnwrapExpr())};
         Expr<SomeType> im{args.size() >= 2 && args[1].has_value()
                 ? std::move(*args[1]->UnwrapExpr())
@@ -53,9 +54,14 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
     }
   } else if (name == "merge") {
     return FoldMerge<T>(context, std::move(funcRef));
+  } else if (name == "product") {
+    auto one{Scalar<Part>::FromInteger(value::Integer<8>{1}).value}; // 1.0
+    return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{one});
+  } else if (name == "sum") {
+    return FoldSum<T>(context, std::move(funcRef));
   }
-  // TODO: cshift, dot_product, eoshift, matmul, pack, product,
-  // reduce, spread, sum, transfer, transpose, unpack
+  // TODO: cshift, dot_product, eoshift, matmul, pack, spread, transfer,
+  // transpose, unpack
   return Expr<T>{std::move(funcRef)};
 }


diff  --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 01f30b085a56b..1fddfe9d02bd7 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -174,6 +174,25 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
   return Expr<T>{std::move(funcRef)};
 }

+// IALL, IANY, & IPARITY: reduce with "operation", starting from "identity"
+template <typename T>
+static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
+    Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
+    Scalar<T> identity) {
+  static_assert(T::category == TypeCategory::Integer);
+  using Element = Scalar<T>;
+  std::optional<ConstantSubscript> dim;
+  if (std::optional<Constant<T>> array{
+          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
+              /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
+    auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
+      element = (element.*operation)(array->At(at));
+    }};
+    return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+  }
+  return Expr<T>{std::move(ref)};
+}
+
 template <int KIND>
 Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
     FoldingContext &context,
@@ -311,6 +330,12 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
     }
     return FoldElementalIntrinsic<T, T, T>(
         context, std::move(funcRef), ScalarFunc<T, T, T>(fptr));
+  } else if (name == "iall") {
+    return FoldBitReduction(
+        context, std::move(funcRef), &Scalar<T>::IAND, Scalar<T>{}.NOT());
+  } else if (name == "iany") {
+    return FoldBitReduction(
+        context, std::move(funcRef), &Scalar<T>::IOR, Scalar<T>{});
   } else if (name == "ibclr" || name == "ibset" || name == "ishft" ||
       name == "shifta" || name == "shiftr" || name == "shiftl") {
     // Second argument can be of any kind. However, it must be smaller or
@@ -393,6 +418,9 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
     } else {
       DIE("kind() result not integral");
     }
+  } else if (name == "iparity") {
+    return FoldBitReduction(
+        context, std::move(funcRef), &Scalar<T>::IEOR, Scalar<T>{});
   } else if (name == "lbound") {
     return LBOUND(context, std::move(funcRef));
   } else if (name == "leadz" || name == "trailz" || name == "poppar" ||
@@ -540,6 +568,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
           },
           cx->u)};
     }
+  } else if (name == "product") {
+    return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{1});
   } else if (name == "radix") {
     return Expr<T>{2};
   } else if (name == "range") {
@@ -654,14 +684,15 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
             Fold(context, Expr<T>{8} * ConvertToType<T>(std::move(*bytes)))};
       }
     }
+  } else if (name == "sum") {
+    return FoldSum<T>(context, std::move(funcRef));
   } else if (name == "ubound") {
     return UBOUND(context, std::move(funcRef));
   }
   // TODO:
-  // cshift, dot_product, eoshift,
-  // findloc, iall, iany, iparity, ibits, image_status, ishftc,
-  // matmul, maxloc, minloc, pack, product, reduce,
-  // sign, spread, sum, transfer, transpose, unpack
+  // cshift, dot_product, eoshift, findloc, ibits, image_status, ishftc,
+  // matmul, maxloc, minloc, not, pack, sign, spread, transfer, transpose,
+  // unpack
   return Expr<T>{std::move(funcRef)};
 }


diff  --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index ca6dcdbcafe88..4af51373ec0dd 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -7,10 +7,30 @@
 //===----------------------------------------------------------------------===//

 #include "fold-implementation.h"
+#include "fold-reduction.h"
 #include "flang/Evaluate/check-expression.h"

 namespace Fortran::evaluate {

+// ALL & ANY: reduce MASK= with "operation", starting from "identity"
+template <typename T>
+static Expr<T> FoldAllAny(FoldingContext &context, FunctionRef<T> &&ref,
+    Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
+    Scalar<T> identity) {
+  static_assert(T::category == TypeCategory::Logical);
+  using Element = Scalar<T>;
+  std::optional<ConstantSubscript> dim;
+  if (std::optional<Constant<T>> array{
+          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
+              /*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
+    auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
+      element = (element.*operation)(array->At(at));
+    }};
+    return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+  }
+  return Expr<T>{std::move(ref)};
+}
+
 template <int KIND>
 Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
     FoldingContext &context,
@@ -21,31 +41,11 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
   CHECK(intrinsic);
   std::string name{intrinsic->name};
   if (name == "all") {
-    if (!args[1]) { // TODO: ALL(x,DIM=d)
-      if (const auto *constant{UnwrapConstantValue<T>(args[0])}) {
-        bool result{true};
-        for (const auto &element : constant->values()) {
-          if (!element.IsTrue()) {
-            result = false;
-            break;
-          }
-        }
-        return Expr<T>{result};
-      }
-    }
+    return FoldAllAny(
+        context, std::move(funcRef), &Scalar<T>::AND, Scalar<T>{true});
   } else if (name == "any") {
-    if (!args[1]) { // TODO: ANY(x,DIM=d)
-      if (const auto *constant{UnwrapConstantValue<T>(args[0])}) {
-        bool result{false};
-        for (const auto &element : constant->values()) {
-          if (element.IsTrue()) {
-            result = true;
-            break;
-          }
-        }
-        return Expr<T>{result};
-      }
-    }
+    return FoldAllAny(
+        context, std::move(funcRef), &Scalar<T>::OR, Scalar<T>{false});
   } else if (name == "associated") {
     bool gotConstant{true};
     const Expr<SomeType> *firstArgExpr{args[0]->UnwrapExpr()};
@@ -127,8 +127,8 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
   }
   // TODO: btest, cshift, dot_product, eoshift, is_iostat_end,
   // is_iostat_eor, lge, lgt, lle, llt, logical, matmul, out_of_range,
-  // pack, parity, reduce, spread, transfer, transpose, unpack,
-  // extends_type_of, same_type_as
+  // pack, parity, spread, transfer, transpose, unpack, extends_type_of,
+  // same_type_as
   return Expr<T>{std::move(funcRef)};
 }


diff  --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 88222689e4f3f..cf13899890969 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -120,6 +120,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
   } else if (name == "minval") {
     return FoldMaxvalMinval<T>(
         context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
+  } else if (name == "product") {
+    auto one{Scalar<T>::FromInteger(value::Integer<8>{1}).value}; // 1.0
+    return FoldProduct<T>(context, std::move(funcRef), one);
   } else if (name == "real") {
     if (auto *expr{args[0].value().UnwrapExpr()}) {
       return ToReal<KIND>(context, std::move(*expr));
@@ -127,14 +130,15 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
   } else if (name == "sign") {
     return FoldElementalIntrinsic<T, T, T>(
         context, std::move(funcRef), &Scalar<T>::SIGN);
+  } else if (name == "sum") {
+    return FoldSum<T>(context, std::move(funcRef));
   } else if (name == "tiny") {
     return Expr<T>{Scalar<T>::TINY()};
   }
   // TODO: cshift, dim, dot_product, eoshift, fraction, matmul,
-  // maxloc, minloc, modulo, nearest, norm2, pack, product,
-  // reduce, rrspacing, scale, set_exponent, spacing, spread,
-  // sum, transfer, transpose, unpack, bessel_jn (transformational) and
-  // bessel_yn (transformational)
+  // maxloc, minloc, modulo, nearest, norm2, pack, rrspacing, scale,
+  // set_exponent, spacing, spread, transfer, transpose, unpack,
+  // bessel_jn (transformational) and bessel_yn (transformational)
   return Expr<T>{std::move(funcRef)};
 }


diff  --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 1c7473a1500a1..8793b00912925 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -16,122 +16,238 @@

 namespace Fortran::evaluate {

-// MAXVAL & MINVAL
+// Common preprocessing for reduction transformational intrinsic function
+// folding.  If the intrinsic can have DIM= &/or MASK= arguments, extract
+// and check them.  If a MASK= is present, apply it to the array data and
+// substitute identity values for elements corresponding to .FALSE. in
+// the mask.  If the result is present, the intrinsic call can be folded.
+// The *Index arguments locate ARRAY=, DIM=, and MASK= within "arg".  When a
+// result is returned, "dim" holds the DIM= value (if any), checked vs. rank.
 template <typename T>
-Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
-    RelationalOperator opr, Scalar<T> identity) {
-  static_assert(T::category == TypeCategory::Integer ||
-      T::category == TypeCategory::Real ||
-      T::category == TypeCategory::Character);
-  using Element = typename Constant<T>::Element;
-  auto &arg{ref.arguments()};
-  CHECK(arg.size() <= 3);
+static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
+    ActualArguments &arg, std::optional<ConstantSubscript> &dim,
+    const Scalar<T> &identity, int arrayIndex,
+    std::optional<std::size_t> dimIndex = std::nullopt,
+    std::optional<std::size_t> maskIndex = std::nullopt) {
   if (arg.empty()) {
-    return Expr<T>{std::move(ref)};
+    return std::nullopt;
   }
-  Constant<T> *array{Folder<T>{context}.Folding(arg[0])};
-  if (!array || array->Rank() < 1) {
-    return Expr<T>{std::move(ref)};
+  Constant<T> *folded{Folder<T>{context}.Folding(arg[arrayIndex])};
+  if (!folded || folded->Rank() < 1) {
+    return std::nullopt;
   }
-  std::optional<ConstantSubscript> dim;
-  if (arg.size() >= 2 && arg[1]) {
-    if (auto *dimConst{Folder<SubscriptInteger>{context}.Folding(arg[1])}) {
+  if (dimIndex && arg.size() >= *dimIndex + 1 && arg[*dimIndex]) {
+    if (auto *dimConst{
+            Folder<SubscriptInteger>{context}.Folding(arg[*dimIndex])}) {
       if (auto dimScalar{dimConst->GetScalarValue()}) {
         dim.emplace(dimScalar->ToInt64());
-        if (*dim < 1 || *dim > array->Rank()) {
+        if (*dim < 1 || *dim > folded->Rank()) {
           context.messages().Say(
               "DIM=%jd is not valid for an array of rank %d"_err_en_US,
-              static_cast<std::intmax_t>(*dim), array->Rank());
+              static_cast<std::intmax_t>(*dim), folded->Rank());
           dim.reset();
         }
       }
     }
     if (!dim) {
-      return Expr<T>{std::move(ref)};
+      return std::nullopt;
     }
   }
-  Constant<LogicalResult> *mask{};
-  if (arg.size() >= 3 && arg[2]) {
-    mask = Folder<LogicalResult>{context}.Folding(arg[2]);
-    if (!mask) {
-      return Expr<T>{std::move(ref)};
-    }
-    if (!CheckConformance(context.messages(), AsShape(array->shape()),
-            AsShape(mask->shape()),
-            CheckConformanceFlags::RightScalarExpandable, "ARRAY=", "MASK=")
-             .value_or(false)) {
-      return Expr<T>{std::move(ref)};
-    }
-  }
-  // Do it
-  ConstantSubscripts at{array->lbounds()}, maskAt;
-  bool maskAllFalse{false};
-  if (mask) {
-    if (auto scalar{mask->GetScalarValue()}) {
-      if (scalar->IsTrue()) {
-        mask = nullptr; // all .TRUE.
+  if (maskIndex && arg.size() >= *maskIndex + 1 && arg[*maskIndex]) {
+    if (Constant<LogicalResult> *
+        mask{Folder<LogicalResult>{context}.Folding(arg[*maskIndex])}) {
+      if (CheckConformance(context.messages(), AsShape(folded->shape()),
+              AsShape(mask->shape()),
+              CheckConformanceFlags::RightScalarExpandable, "ARRAY=", "MASK=")
+              .value_or(false)) {
+        // Apply the mask in place to the array
+        std::size_t n{folded->size()};
+        std::vector<typename Constant<T>::Element> elements;
+        if (auto scalarMask{mask->GetScalarValue()}) {
+          if (scalarMask->IsTrue()) {
+            return Constant<T>{*folded};
+          } else { // MASK=.FALSE.
+            elements = std::vector<typename Constant<T>::Element>(n, identity);
+          }
+        } else { // mask is an array; test its elements
+          elements = std::vector<typename Constant<T>::Element>(n, identity);
+          ConstantSubscripts at{folded->lbounds()};
+          for (std::size_t j{0}; j < n; ++j, folded->IncrementSubscripts(at)) {
+            if (mask->values()[j].IsTrue()) {
+              elements[j] = folded->At(at);
+            }
+          }
+        }
+        if constexpr (T::category == TypeCategory::Character) {
+          return Constant<T>{static_cast<ConstantSubscript>(identity.size()),
+              std::move(elements), ConstantSubscripts{folded->shape()}};
+        } else {
+          return Constant<T>{
+              std::move(elements), ConstantSubscripts{folded->shape()}};
+        }
       } else {
-        maskAllFalse = true;
+        return std::nullopt;
       }
     } else {
-      maskAt = mask->lbounds();
+      return std::nullopt;
     }
+  } else {
+    return Constant<T>{*folded};
   }
-  std::vector<Element> result;
+}
+
+// Generalized reduction to an array of one dimension fewer (w/ DIM=)
+// or to a scalar (w/o DIM=).
+// Each result element starts as "identity"; accumulator(element, at) then
+// folds the array element at subscripts "at" into that result element.
+template <typename T, typename ACCUMULATOR>
+static Constant<T> DoReduction(const Constant<T> &array,
+    const std::optional<ConstantSubscript> &dim, const Scalar<T> &identity,
+    ACCUMULATOR &accumulator) {
+  ConstantSubscripts at{array.lbounds()};
+  std::vector<typename Constant<T>::Element> elements;
   ConstantSubscripts resultShape; // empty -> scalar
-  // Internal function to accumulate into result.back().
-  auto Accumulate{[&]() {
-    if (!maskAllFalse && (maskAt.empty() || mask->At(maskAt).IsTrue())) {
-      Expr<LogicalResult> test{
-          PackageRelation(opr, Expr<T>{Constant<T>{array->At(at)}},
-              Expr<T>{Constant<T>{result.back()}})};
-      auto folded{GetScalarConstantValue<LogicalResult>(
-          test.Rewrite(context, std::move(test)))};
-      CHECK(folded.has_value());
-      if (folded->IsTrue()) {
-        result.back() = array->At(at);
-      }
-    }
-  }};
   if (dim) { // DIM= is present, so result is an array
-    resultShape = array->shape();
+    resultShape = array.shape();
     resultShape.erase(resultShape.begin() + (*dim - 1));
-    ConstantSubscript dimExtent{array->shape().at(*dim - 1)};
+    ConstantSubscript dimExtent{array.shape().at(*dim - 1)};
     ConstantSubscript &dimAt{at[*dim - 1]};
     ConstantSubscript dimLbound{dimAt};
-    ConstantSubscript *maskDimAt{maskAt.empty() ? nullptr : &maskAt[*dim - 1]};
-    ConstantSubscript maskLbound{maskDimAt ? *maskDimAt : 0};
     for (auto n{GetSize(resultShape)}; n-- > 0;
-         IncrementSubscripts(at, array->shape())) {
+         IncrementSubscripts(at, array.shape())) {
       dimAt = dimLbound;
-      if (maskDimAt) {
-        *maskDimAt = maskLbound;
-      }
-      result.push_back(identity);
-      for (ConstantSubscript j{0}; j < dimExtent;
-           ++j, ++dimAt, maskDimAt && ++*maskDimAt) {
-        Accumulate();
-      }
-      if (maskDimAt) {
-        IncrementSubscripts(maskAt, mask->shape());
+      elements.push_back(identity);
+      for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt) {
+        accumulator(elements.back(), at);
       }
     }
   } else { // no DIM=, result is scalar
-    result.push_back(identity);
-    for (auto n{array->size()}; n-- > 0;
-         IncrementSubscripts(at, array->shape())) {
-      Accumulate();
-      if (!maskAt.empty()) {
-        IncrementSubscripts(maskAt, mask->shape());
-      }
+    elements.push_back(identity);
+    for (auto n{array.size()}; n-- > 0;
+         IncrementSubscripts(at, array.shape())) {
+      accumulator(elements.back(), at);
     }
   }
   if constexpr (T::category == TypeCategory::Character) {
-    return Expr<T>{Constant<T>{static_cast<ConstantSubscript>(identity.size()),
-        std::move(result), std::move(resultShape)}};
+    return {static_cast<ConstantSubscript>(identity.size()),
+        std::move(elements), std::move(resultShape)};
   } else {
-    return Expr<T>{Constant<T>{std::move(result), std::move(resultShape)}};
+    return {std::move(elements), std::move(resultShape)};
+  }
+}
+
+// MAXVAL & MINVAL
+// An array element X replaces the running result R when "X OPR R" holds.
+template <typename T>
+static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
+    RelationalOperator opr, const Scalar<T> &identity) {
+  static_assert(T::category == TypeCategory::Integer ||
+      T::category == TypeCategory::Real ||
+      T::category == TypeCategory::Character);
+  using Element = Scalar<T>;
+  std::optional<ConstantSubscript> dim;
+  if (std::optional<Constant<T>> array{
+          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
+              /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
+    auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
+      Expr<LogicalResult> test{PackageRelation(opr,
+          Expr<T>{Constant<T>{array->At(at)}}, Expr<T>{Constant<T>{element}})};
+      auto folded{GetScalarConstantValue<LogicalResult>(
+          test.Rewrite(context, std::move(test)))};
+      CHECK(folded.has_value());
+      if (folded->IsTrue()) {
+        element = array->At(at);
+      }
+    }};
+    return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+  }
+  return Expr<T>{std::move(ref)};
+}
+
+// PRODUCT
+// If any multiplication overflows, a warning is emitted and the call is
+// left unfolded.
+template <typename T>
+static Expr<T> FoldProduct(
+    FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
+  static_assert(T::category == TypeCategory::Integer ||
+      T::category == TypeCategory::Real ||
+      T::category == TypeCategory::Complex);
+  using Element = typename Constant<T>::Element;
+  std::optional<ConstantSubscript> dim;
+  if (std::optional<Constant<T>> array{
+          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
+              /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
+    bool overflow{false};
+    auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
+      if constexpr (T::category == TypeCategory::Integer) {
+        auto prod{element.MultiplySigned(array->At(at))};
+        overflow |= prod.SignedMultiplicationOverflowed();
+        element = prod.lower;
+      } else { // Real & Complex
+        auto prod{element.Multiply(array->At(at))};
+        overflow |= prod.flags.test(RealFlag::Overflow);
+        element = prod.value;
+      }
+    }};
+    // Reduce first: running the accumulator is what sets "overflow".
+    auto result{Expr<T>{DoReduction(*array, dim, identity, accumulator)}};
+    if (!overflow) {
+      return result;
+    }
+    context.messages().Say(
+        "PRODUCT() of %s data overflowed"_en_US, T::AsFortran());
+  }
+  return Expr<T>{std::move(ref)};
+}
+
+// SUM
+// Real & complex sums use Kahan compensated summation, restarted for each
+// DIM= slice.  On overflow, a warning is emitted and the call is not folded.
+template <typename T>
+static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
+  static_assert(T::category == TypeCategory::Integer ||
+      T::category == TypeCategory::Real ||
+      T::category == TypeCategory::Complex);
+  using Element = typename Constant<T>::Element;
+  std::optional<ConstantSubscript> dim;
+  Element identity{}, correction{};
+  if (std::optional<Constant<T>> array{
+          ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
+              /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
+    // Lower bound of the DIM= dimension: DoReduction visits this subscript
+    // first in each slice that it reduces to one result element.
+    ConstantSubscript dimLbound{dim ? array->lbounds().at(*dim - 1) : 0};
+    bool overflow{false};
+    auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
+      if constexpr (T::category == TypeCategory::Integer) {
+        auto sum{element.AddSigned(array->At(at))};
+        overflow |= sum.overflow;
+        element = sum.value;
+      } else { // Real & Complex: use Kahan summation
+        if (dim && at[*dim - 1] == dimLbound) {
+          correction = Element{}; // new DIM= slice: discard old compensation
+        }
+        // Compensate: take away the rounding error left by the previous Add
+        auto next{array->At(at).Subtract(correction)};
+        overflow |= next.flags.test(RealFlag::Overflow);
+        auto sum{element.Add(next.value)};
+        overflow |= sum.flags.test(RealFlag::Overflow);
+        // correction = (sum - element) - next; algebraically zero
+        correction =
+            sum.value.Subtract(element).value.Subtract(next.value).value;
+        element = sum.value;
+      }
+    }};
+    // Reduce first: running the accumulator is what sets "overflow".
+    auto result{Expr<T>{DoReduction(*array, dim, identity, accumulator)}};
+    if (!overflow) {
+      return result;
+    }
+    context.messages().Say(
+        "SUM() of %s data overflowed"_en_US, T::AsFortran());
   }
+  return Expr<T>{std::move(ref)};
 }

 } // namespace Fortran::evaluate

diff  --git a/flang/test/Evaluate/folding20.f90 b/flang/test/Evaluate/folding20.f90
index dcb30d8598fd8..7d14e978dd838 100644
--- a/flang/test/Evaluate/folding20.f90
+++ b/flang/test/Evaluate/folding20.f90
@@ -1,13 +1,41 @@
 ! RUN: %S/test_folding.sh %s %t %flang_fc1
 ! REQUIRES: shell
-! Tests intrinsic MAXVAL/MINVAL function folding
+! Tests reduction intrinsic function folding
 module m
+  implicit none
+  integer, parameter :: intmatrix(*,*) = reshape([1, 2, 3, 4, 5, 6], [2, 3]) ! rows: [1,3,5], [2,4,6]
+  logical, parameter :: odds(2,3) = mod(intmatrix, 2) == 1
+  character(*), parameter :: chmatrix(*,*) = reshape(['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr'], [2, 3])
+
+  logical, parameter :: test_allidentity = all([Logical::])
+  logical, parameter :: test_all = .not. all(odds)
+  logical, parameter :: test_alldim1 = all(.not. all(odds,1))
+  logical, parameter :: test_alldim2 = all(all(odds,2) .eqv. [.true., .false.])
+  logical, parameter :: test_anyidentity = .not. any([Logical::])
+  logical, parameter :: test_any = any(odds)
+  logical, parameter :: test_anydim1 = all(any(odds,1))
+  logical, parameter :: test_anydim2 = all(any(odds,2) .eqv. [.true., .false.])
+
+  logical, parameter :: test_iallidentity = iall([integer::]) == -1
+  logical, parameter :: test_iall = iall(intmatrix) == 0
+  logical, parameter :: test_iall_masked = iall(intmatrix,odds) == 1
+  logical, parameter :: test_ialldim1 = all(iall(intmatrix,dim=1) == [0, 0, 4])
+  logical, parameter :: test_ialldim2 = all(iall(intmatrix,dim=2) == [1, 0])
+  logical, parameter :: test_ianyidentity = iany([integer::]) == 0
+  logical, parameter :: test_iany = iany(intmatrix) == 7
+  logical, parameter :: test_iany_masked = iany(intmatrix,odds) == 7
+  logical, parameter :: test_ianydim1 = all(iany(intmatrix,dim=1) == [3, 7, 7])
+  logical, parameter :: test_ianydim2 = all(iany(intmatrix,dim=2) == [7, 6])
+  logical, parameter :: test_iparityidentity = iparity([integer::]) == 0
+  logical, parameter :: test_iparity = iparity(intmatrix) == 7
+  logical, parameter :: test_iparity_masked = iparity(intmatrix,odds) == 7
+  logical, parameter :: test_iparitydim1 = all(iparity(intmatrix,dim=1) == [3, 7, 3])
+  logical, parameter :: test_iparitydim2 = all(iparity(intmatrix,dim=2) == [7, 0])
+
   logical, parameter :: test_imaxidentity = maxval([integer::]) == -huge(0) - 1
   logical, parameter :: test_iminidentity = minval([integer::]) == huge(0)
-  integer, parameter :: intmatrix(*,*) = reshape([1, 2, 3, 4, 5, 6], [2, 3])
   logical, parameter :: test_imaxval = maxval(intmatrix) == 6
   logical, parameter :: test_iminval = minval(intmatrix) == 1
-  logical, parameter :: odds(2,3) = mod(intmatrix, 2) == 1
   logical, parameter :: test_imaxval_masked = maxval(intmatrix,odds) == 5
   logical, parameter :: test_iminval_masked = minval(intmatrix,.not.odds) == 2
   logical, parameter :: test_rmaxidentity = maxval([real::]) == -huge(0.0)
@@ -16,12 +44,31 @@ module m
   logical, parameter :: test_rminval = minval(real(intmatrix)) == 1.0
   logical, parameter :: test_rmaxval_scalar_mask = maxval(real(intmatrix), .true.) == 6.0
   logical, parameter :: test_rminval_scalar_mask = minval(real(intmatrix), .false.) == huge(0.0)
-  character(*), parameter :: chmatrix(*,*) = reshape(['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr'], [2, 3])
   logical, parameter :: test_cmaxlen = len(maxval([character*4::])) == 4
   logical, parameter :: test_cmaxidentity = maxval([character*4::]) == repeat(char(0), 4)
   logical, parameter :: test_cminidentity = minval([character*4::]) == repeat(char(127), 4)
   logical, parameter :: test_cmaxval = maxval(chmatrix) == 'pqr'
   logical, parameter :: test_cminval = minval(chmatrix) == 'abc'
-  logical, parameter :: test_dim1 = all(maxval(intmatrix,dim=1) == [2, 4, 6])
-  logical, parameter :: test_dim2 = all(minval(intmatrix,dim=2,mask=odds) == [1, huge(0)])
+  logical, parameter :: test_maxvaldim1 = all(maxval(intmatrix,dim=1) == [2, 4, 6])
+  logical, parameter :: test_minvaldim2 = all(minval(intmatrix,dim=2,mask=odds) == [1, huge(0)])
+
+  logical, parameter :: test_iproductidentity = product([integer::]) == 1
+  logical, parameter :: test_iproduct = product(intmatrix) == 720
+  logical, parameter :: test_iproduct_masked = product(intmatrix,odds) == 15
+  logical, parameter :: test_productdim1 = all(product(intmatrix,dim=1) == [2, 12, 30])
+  logical, parameter :: test_productdim2 = all(product(intmatrix,dim=2) == [15, 48])
+  logical, parameter :: test_rproductidentity = product([real::]) == 1.
+  logical, parameter :: test_rproduct = product(real(intmatrix)) == 720.
+  logical, parameter :: test_cproductidentity = product([complex::]) == (1.,0.)
+  logical, parameter :: test_cproduct = product(cmplx(intmatrix,-intmatrix)) == (0.,5760.) ! 720*(1-i)**6 = 720*8i
+
+  logical, parameter :: test_isumidentity = sum([integer::]) == 0
+  logical, parameter :: test_isum = sum(intmatrix) == 21
+  logical, parameter :: test_isum_masked = sum(intmatrix,odds) == 9
+  logical, parameter :: test_sumdim1 = all(sum(intmatrix,dim=1) == [3, 7, 11])
+  logical, parameter :: test_sumdim2 = all(sum(intmatrix,dim=2) == [9, 12])
+  logical, parameter :: test_rsumidentity = sum([real::]) == 0.
+  logical, parameter :: test_rsum = sum(real(intmatrix)) == 21.
+  logical, parameter :: test_csumidentity = sum([complex::]) == (0.,0.)
+  logical, parameter :: test_csum = sum(cmplx(intmatrix,-intmatrix)) == (21.,-21.)
 end


        


More information about the flang-commits mailing list