[flang-commits] [flang] e723c69 - [flang] Fold DOT_PRODUCT()

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Thu Aug 25 16:52:41 PDT 2022


Author: Peter Klausler
Date: 2022-08-25T16:52:21-07:00
New Revision: e723c69b94b9ac9c6977c0df011ee6219e67da4d

URL: https://github.com/llvm/llvm-project/commit/e723c69b94b9ac9c6977c0df011ee6219e67da4d
DIFF: https://github.com/llvm/llvm-project/commit/e723c69b94b9ac9c6977c0df011ee6219e67da4d.diff

LOG: [flang] Fold DOT_PRODUCT()

Implement constant folding of the intrinsic function DOT_PRODUCT().

Differential Revision: https://reviews.llvm.org/D132688

Added: 
    flang/test/Evaluate/fold-dot.f90

Modified: 
    flang/include/flang/Evaluate/fold.h
    flang/lib/Evaluate/fold-complex.cpp
    flang/lib/Evaluate/fold-integer.cpp
    flang/lib/Evaluate/fold-logical.cpp
    flang/lib/Evaluate/fold-real.cpp
    flang/lib/Evaluate/fold-reduction.h

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Evaluate/fold.h b/flang/include/flang/Evaluate/fold.h
index e7081a06dddb2..24fb54761962c 100644
--- a/flang/include/flang/Evaluate/fold.h
+++ b/flang/include/flang/Evaluate/fold.h
@@ -57,10 +57,8 @@ auto UnwrapConstantValue(EXPR &expr) -> common::Constify<Constant<T>, EXPR> * {
   if (auto *c{UnwrapExpr<Constant<T>>(expr)}) {
     return c;
   } else {
-    if constexpr (!std::is_same_v<T, SomeDerived>) {
-      if (auto *parens{UnwrapExpr<Parentheses<T>>(expr)}) {
-        return UnwrapConstantValue<T>(parens->left());
-      }
+    if (auto *parens{UnwrapExpr<Parentheses<T>>(expr)}) {
+      return UnwrapConstantValue<T>(parens->left());
     }
     return nullptr;
   }

diff  --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp
index 3cd7c8490c582..efdb18f889132 100644
--- a/flang/lib/Evaluate/fold-complex.cpp
+++ b/flang/lib/Evaluate/fold-complex.cpp
@@ -62,6 +62,8 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
                     ToReal<KIND>(context, std::move(im))}});
       }
     }
+  } else if (name == "dot_product") {
+    return FoldDotProduct<T>(context, std::move(funcRef));
   } else if (name == "merge") {
     return FoldMerge<T>(context, std::move(funcRef));
   } else if (name == "product") {
@@ -70,7 +72,7 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
   } else if (name == "sum") {
     return FoldSum<T>(context, std::move(funcRef));
   }
-  // TODO: dot_product, matmul
+  // TODO: matmul
   return Expr<T>{std::move(funcRef)};
 }
 

diff  --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 0a69e2c61b3d5..9bb31a0165825 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -552,6 +552,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
   } else if (name == "dim") {
     return FoldElementalIntrinsic<T, T, T>(
         context, std::move(funcRef), &Scalar<T>::DIM);
+  } else if (name == "dot_product") {
+    return FoldDotProduct<T>(context, std::move(funcRef));
   } else if (name == "dshiftl" || name == "dshiftr") {
     const auto fptr{
         name == "dshiftl" ? &Scalar<T>::DSHIFTL : &Scalar<T>::DSHIFTR};

diff  --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index bcf59a5d12136..052fe62bbd5dc 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -140,6 +140,8 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
           },
           ix->u);
     }
+  } else if (name == "dot_product") {
+    return FoldDotProduct<T>(context, std::move(funcRef));
   } else if (name == "extends_type_of") {
     // Type extension testing with EXTENDS_TYPE_OF() ignores any type
     // parameters. Returns a constant truth value when the result is known now.
@@ -231,7 +233,7 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
       name == "__builtin_ieee_support_underflow_control") {
     return Expr<T>{true};
   }
-  // TODO: dot_product, is_iostat_end,
+  // TODO: is_iostat_end,
   // is_iostat_eor, logical, matmul, out_of_range,
   // parity
   return Expr<T>{std::move(funcRef)};

diff  --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 6bdc922d40b7d..59b7637ae9947 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -136,6 +136,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
             [](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> {
               return x.DIM(y).value;
             }));
+  } else if (name == "dot_product") {
+    return FoldDotProduct<T>(context, std::move(funcRef));
   } else if (name == "dprod") {
     if (auto scalars{GetScalarConstantArguments<T, T>(context, args)}) {
       return Fold(context,

diff  --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 8043212820f59..89b5141b2f130 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-// TODO: DOT_PRODUCT, NORM2, PARITY
+// TODO: NORM2, PARITY
 
 #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
 #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
@@ -15,10 +15,96 @@
 
 namespace Fortran::evaluate {
 
-// Fold and validate a DIM= argument.  Returns true (with &dim empty)
-// when DIM= is not present or (with &dim set) when DIM= is present, constant,
-// and valid.  Returns false, possibly with an error message, when
-// DIM= is present but either not constant or not valid.
+// DOT_PRODUCT: when both arguments fold to constants, returns their scalar dot product as a constant; otherwise returns the reference unfolded.
+template <typename T>
+static Expr<T> FoldDotProduct(
+    FoldingContext &context, FunctionRef<T> &&funcRef) {
+  using Element = typename Constant<T>::Element;
+  auto args{funcRef.arguments()};
+  CHECK(args.size() == 2);
+  Folder<T> folder{context};
+  Constant<T> *va{folder.Folding(args[0])};
+  Constant<T> *vb{folder.Folding(args[1])};
+  if (va && vb) {
+    CHECK(va->Rank() == 1 && vb->Rank() == 1);
+    if (va->size() != vb->size()) { // mismatched extents: diagnose, don't fold
+      context.messages().Say(
+          "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US,
+          va->size(), vb->size());
+      return MakeInvalidIntrinsic(std::move(funcRef));
+    }
+    Element sum{}; // value-initialized accumulator; also the result for empty vectors
+    bool overflow{false};
+    if constexpr (T::category == TypeCategory::Complex) {
+      std::vector<Element> conjugates; // first operand is conjugated before multiplying
+      for (const Element &x : va->values()) {
+        conjugates.emplace_back(x.CONJG());
+      }
+      Constant<T> conjgA{
+          std::move(conjugates), ConstantSubscripts{va->shape()}};
+      Expr<T> products{Fold(
+          context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
+      Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
+      Element correction; // Kahan summation: rounding excess absorbed by sum so far.
+      const auto &rounding{context.targetCharacteristics().roundingMode()};
+      for (const Element &x : cProducts.values()) {
+        auto next{x.Subtract(correction, rounding)}; // remove prior excess from this term
+        overflow |= next.flags.test(RealFlag::Overflow);
+        auto added{sum.Add(next.value, rounding)};
+        overflow |= added.flags.test(RealFlag::Overflow);
+        correction = added.value.Subtract(sum, rounding)
+                         .value.Subtract(next.value, rounding)
+                         .value; // (added - sum) - next: algebraically zero
+        sum = std::move(added.value);
+      }
+    } else if constexpr (T::category == TypeCategory::Logical) {
+      Expr<T> conjunctions{Fold(context,
+          Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And,
+              Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})};
+      Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))};
+      for (const Element &x : cConjunctions.values()) {
+        if (x.IsTrue()) { // result is true as soon as any elementwise AND is true
+          sum = Element{true};
+          break;
+        }
+      }
+    } else if constexpr (T::category == TypeCategory::Integer) {
+      Expr<T> products{
+          Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
+      Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
+      for (const Element &x : cProducts.values()) {
+        auto next{sum.AddSigned(x)};
+        overflow |= next.overflow;
+        sum = std::move(next.value);
+      }
+    } else { // T::category == TypeCategory::Real
+      Expr<T> products{
+          Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
+      Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
+      Element correction; // Kahan summation: rounding excess absorbed by sum so far.
+      const auto &rounding{context.targetCharacteristics().roundingMode()};
+      for (const Element &x : cProducts.values()) {
+        auto next{x.Subtract(correction, rounding)}; // remove prior excess from this term
+        overflow |= next.flags.test(RealFlag::Overflow);
+        auto added{sum.Add(next.value, rounding)};
+        overflow |= added.flags.test(RealFlag::Overflow);
+        correction = added.value.Subtract(sum, rounding)
+                         .value.Subtract(next.value, rounding)
+                         .value; // (added - sum) - next: algebraically zero
+        sum = std::move(added.value);
+      }
+    }
+    if (overflow) { // only overflow in the summation loops above is tracked here
+      context.messages().Say(
+          "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
+          T::AsFortran());
+    }
+    return Expr<T>{Constant<T>{std::move(sum)}};
+  }
+  return Expr<T>{std::move(funcRef)};
+}
+
+
+// Fold and validate a DIM= argument.  Returns false on error.
 bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &,
     ActualArguments &, std::optional<int> dimIndex, int rank);
 
@@ -203,13 +289,15 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
         overflow |= sum.overflow;
         element = sum.value;
       } else { // Real & Complex: use Kahan summation
-        auto next{array->At(at).Add(correction)};
+        const auto &rounding{context.targetCharacteristics().roundingMode()};
+        auto next{array->At(at).Add(correction, rounding)};
         overflow |= next.flags.test(RealFlag::Overflow);
-        auto sum{element.Add(next.value)};
+        auto sum{element.Add(next.value, rounding)};
         overflow |= sum.flags.test(RealFlag::Overflow);
         // correction = (sum - element) - next; algebraically zero
-        correction =
-            sum.value.Subtract(element).value.Subtract(next.value).value;
+        correction = sum.value.Subtract(element, rounding)
+                         .value.Subtract(next.value, rounding)
+                         .value;
         element = sum.value;
       }
     }};

diff  --git a/flang/test/Evaluate/fold-dot.f90 b/flang/test/Evaluate/fold-dot.f90
new file mode 100644
index 0000000000000..fb1a878ecd353
--- /dev/null
+++ b/flang/test/Evaluate/fold-dot.f90
@@ -0,0 +1,10 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of DOT_PRODUCT()
+module m
+  logical, parameter :: test_i4a = dot_product([(j,j=1,10)],[(j,j=1,10)]) == sum([(j*j,j=1,10)])
+  logical, parameter :: test_r4a = dot_product([(1.*j,j=1,10)],[(j,j=1,10)]) == sum([(j*j,j=1,10)])
+  logical, parameter :: test_z4a = dot_product([((j,j),j=1,10)],[((j,j),j=1,10)]) == sum([(((j,-j)*(j,j)),j=1,10)])
+  logical, parameter :: test_l4a = .not. dot_product([logical::],[logical::])
+  logical, parameter :: test_l4b = .not. dot_product([(j==2,j=1,10)], [(j==3,j=1,10)])
+  logical, parameter :: test_l4c = dot_product([(j==4,j=1,10)], [(j==4,j=1,10)])
+end


        


More information about the flang-commits mailing list