[flang-commits] [flang] e723c69 - [flang] Fold DOT_PRODUCT()
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Thu Aug 25 16:52:41 PDT 2022
Author: Peter Klausler
Date: 2022-08-25T16:52:21-07:00
New Revision: e723c69b94b9ac9c6977c0df011ee6219e67da4d
URL: https://github.com/llvm/llvm-project/commit/e723c69b94b9ac9c6977c0df011ee6219e67da4d
DIFF: https://github.com/llvm/llvm-project/commit/e723c69b94b9ac9c6977c0df011ee6219e67da4d.diff
LOG: [flang] Fold DOT_PRODUCT()
Implement constant folding of the intrinsic function DOT_PRODUCT().
Differential Revision: https://reviews.llvm.org/D132688
Added:
flang/test/Evaluate/fold-dot.f90
Modified:
flang/include/flang/Evaluate/fold.h
flang/lib/Evaluate/fold-complex.cpp
flang/lib/Evaluate/fold-integer.cpp
flang/lib/Evaluate/fold-logical.cpp
flang/lib/Evaluate/fold-real.cpp
flang/lib/Evaluate/fold-reduction.h
Removed:
################################################################################
diff --git a/flang/include/flang/Evaluate/fold.h b/flang/include/flang/Evaluate/fold.h
index e7081a06dddb2..24fb54761962c 100644
--- a/flang/include/flang/Evaluate/fold.h
+++ b/flang/include/flang/Evaluate/fold.h
@@ -57,10 +57,8 @@ auto UnwrapConstantValue(EXPR &expr) -> common::Constify<Constant<T>, EXPR> * {
if (auto *c{UnwrapExpr<Constant<T>>(expr)}) {
return c;
} else {
- if constexpr (!std::is_same_v<T, SomeDerived>) {
- if (auto *parens{UnwrapExpr<Parentheses<T>>(expr)}) {
- return UnwrapConstantValue<T>(parens->left());
- }
+ if (auto *parens{UnwrapExpr<Parentheses<T>>(expr)}) {
+ return UnwrapConstantValue<T>(parens->left());
}
return nullptr;
}
diff --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp
index 3cd7c8490c582..efdb18f889132 100644
--- a/flang/lib/Evaluate/fold-complex.cpp
+++ b/flang/lib/Evaluate/fold-complex.cpp
@@ -62,6 +62,8 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
ToReal<KIND>(context, std::move(im))}});
}
}
+ } else if (name == "dot_product") {
+ return FoldDotProduct<T>(context, std::move(funcRef));
} else if (name == "merge") {
return FoldMerge<T>(context, std::move(funcRef));
} else if (name == "product") {
@@ -70,7 +72,7 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
} else if (name == "sum") {
return FoldSum<T>(context, std::move(funcRef));
}
- // TODO: dot_product, matmul
+ // TODO: matmul
return Expr<T>{std::move(funcRef)};
}
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 0a69e2c61b3d5..9bb31a0165825 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -552,6 +552,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else if (name == "dim") {
return FoldElementalIntrinsic<T, T, T>(
context, std::move(funcRef), &Scalar<T>::DIM);
+ } else if (name == "dot_product") {
+ return FoldDotProduct<T>(context, std::move(funcRef));
} else if (name == "dshiftl" || name == "dshiftr") {
const auto fptr{
name == "dshiftl" ? &Scalar<T>::DSHIFTL : &Scalar<T>::DSHIFTR};
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index bcf59a5d12136..052fe62bbd5dc 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -140,6 +140,8 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
},
ix->u);
}
+ } else if (name == "dot_product") {
+ return FoldDotProduct<T>(context, std::move(funcRef));
} else if (name == "extends_type_of") {
// Type extension testing with EXTENDS_TYPE_OF() ignores any type
// parameters. Returns a constant truth value when the result is known now.
@@ -231,7 +233,7 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
name == "__builtin_ieee_support_underflow_control") {
return Expr<T>{true};
}
- // TODO: dot_product, is_iostat_end,
+ // TODO: is_iostat_end,
// is_iostat_eor, logical, matmul, out_of_range,
// parity
return Expr<T>{std::move(funcRef)};
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 6bdc922d40b7d..59b7637ae9947 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -136,6 +136,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
[](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> {
return x.DIM(y).value;
}));
+ } else if (name == "dot_product") {
+ return FoldDotProduct<T>(context, std::move(funcRef));
} else if (name == "dprod") {
if (auto scalars{GetScalarConstantArguments<T, T>(context, args)}) {
return Fold(context,
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 8043212820f59..89b5141b2f130 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-// TODO: DOT_PRODUCT, NORM2, PARITY
+// TODO: NORM2, PARITY
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
@@ -15,10 +15,96 @@
namespace Fortran::evaluate {
-// Fold and validate a DIM= argument. Returns true (with &dim empty)
-// when DIM= is not present or (with &dim set) when DIM= is present, constant,
-// and valid. Returns false, possibly with an error message, when
-// DIM= is present but either not constant or not valid.
+// DOT_PRODUCT
+
+// Kahan (compensated) summation over all elements of a REAL or COMPLEX
+// constant, shared by the REAL and COMPLEX cases of FoldDotProduct().
+// Sets 'overflow' when any intermediate addition or subtraction overflows;
+// it never clears it.
+template <typename T>
+static typename Constant<T>::Element DotProductKahanSum(
+    FoldingContext &context, const Constant<T> &addends, bool &overflow) {
+  using Element = typename Constant<T>::Element;
+  const auto &rounding{context.targetCharacteristics().roundingMode()};
+  Element sum{};
+  Element correction{}; // running compensation for rounding error
+  for (const Element &x : addends.values()) {
+    // The correction below is (sum' - sum) - next, i.e. the excess that the
+    // rounded addition introduced, so it must be SUBTRACTED from the next
+    // addend; adding it would undo the compensation on every other step.
+    auto next{x.Subtract(correction, rounding)};
+    overflow |= next.flags.test(RealFlag::Overflow);
+    auto added{sum.Add(next.value, rounding)};
+    overflow |= added.flags.test(RealFlag::Overflow);
+    // correction = (added - sum) - next; algebraically zero
+    correction = added.value.Subtract(sum, rounding)
+                     .value.Subtract(next.value, rounding)
+                     .value;
+    sum = std::move(added.value);
+  }
+  return sum;
+}
+
+// Folds DOT_PRODUCT(VECTOR_A, VECTOR_B) to a scalar constant when both
+// arguments fold to rank-1 constants of the result type T:
+//   INTEGER -> SUM(VECTOR_A * VECTOR_B)
+//   REAL    -> SUM(VECTOR_A * VECTOR_B), with compensated summation
+//   COMPLEX -> SUM(CONJG(VECTOR_A) * VECTOR_B), with compensated summation
+//   LOGICAL -> ANY(VECTOR_A .AND. VECTOR_B)
+// Distinct extents produce an error and an invalid intrinsic reference.
+// Overflow during accumulation produces a warning, but the (overflowed)
+// result is still folded. If either argument is not constant, funcRef is
+// returned unchanged.
+template <typename T>
+static Expr<T> FoldDotProduct(
+    FoldingContext &context, FunctionRef<T> &&funcRef) {
+  using Element = typename Constant<T>::Element;
+  // Fold a copy of the arguments, so that funcRef is returned as it was
+  // when folding is not possible.
+  auto args{funcRef.arguments()};
+  CHECK(args.size() == 2);
+  Folder<T> folder{context};
+  Constant<T> *va{folder.Folding(args[0])};
+  Constant<T> *vb{folder.Folding(args[1])};
+  if (!va || !vb) {
+    return Expr<T>{std::move(funcRef)};
+  }
+  CHECK(va->Rank() == 1 && vb->Rank() == 1);
+  if (va->size() != vb->size()) {
+    context.messages().Say(
+        "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US,
+        va->size(), vb->size());
+    return MakeInvalidIntrinsic(std::move(funcRef));
+  }
+  Element sum{};
+  bool overflow{false};
+  if constexpr (T::category == TypeCategory::Complex) {
+    // COMPLEX DOT_PRODUCT conjugates the first argument.
+    std::vector<Element> conjugates;
+    conjugates.reserve(va->size());
+    for (const Element &x : va->values()) {
+      conjugates.emplace_back(x.CONJG());
+    }
+    Constant<T> conjgA{
+        std::move(conjugates), ConstantSubscripts{va->shape()}};
+    Expr<T> products{Fold(
+        context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
+    const Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
+    sum = DotProductKahanSum<T>(context, cProducts, overflow);
+  } else if constexpr (T::category == TypeCategory::Logical) {
+    // LOGICAL DOT_PRODUCT is ANY(VECTOR_A .AND. VECTOR_B); 'sum' starts
+    // out as the default Element and becomes .TRUE. on the first hit.
+    Expr<T> conjunctions{Fold(context,
+        Expr<T>{LogicalOperation<T::kind>{LogicalOperator::And,
+            Expr<T>{Constant<T>{*va}}, Expr<T>{Constant<T>{*vb}}}})};
+    Constant<T> &cConjunctions{DEREF(UnwrapConstantValue<T>(conjunctions))};
+    for (const Element &x : cConjunctions.values()) {
+      if (x.IsTrue()) {
+        sum = Element{true};
+        break;
+      }
+    }
+  } else if constexpr (T::category == TypeCategory::Integer) {
+    Expr<T> products{
+        Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
+    Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
+    for (const Element &x : cProducts.values()) {
+      auto next{sum.AddSigned(x)};
+      overflow |= next.overflow;
+      sum = std::move(next.value);
+    }
+  } else { // T::category == TypeCategory::Real
+    Expr<T> products{
+        Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
+    const Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
+    sum = DotProductKahanSum<T>(context, cProducts, overflow);
+  }
+  if (overflow) {
+    context.messages().Say(
+        "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US,
+        T::AsFortran());
+  }
+  return Expr<T>{Constant<T>{std::move(sum)}};
+}
+
+// Fold and validate a DIM= argument.
+// Returns true with 'dim' left empty when DIM= is not present, or with
+// 'dim' set when DIM= is present, constant, and valid. Returns false,
+// possibly after emitting an error message, when DIM= is present but is
+// either not constant or not valid.
+// NOTE(review): 'dimIndex' presumably locates DIM= among the actual
+// arguments and 'rank' is the rank of the reduced array -- confirm.
bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &,
ActualArguments &, std::optional<int> dimIndex, int rank);
@@ -203,13 +289,15 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
overflow |= sum.overflow;
element = sum.value;
} else { // Real & Complex: use Kahan summation
- auto next{array->At(at).Add(correction)};
+ const auto &rounding{context.targetCharacteristics().roundingMode()};
+ auto next{array->At(at).Add(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
- auto sum{element.Add(next.value)};
+ auto sum{element.Add(next.value, rounding)};
overflow |= sum.flags.test(RealFlag::Overflow);
// correction = (sum - element) - next; algebraically zero
- correction =
- sum.value.Subtract(element).value.Subtract(next.value).value;
+ correction = sum.value.Subtract(element, rounding)
+ .value.Subtract(next.value, rounding)
+ .value;
element = sum.value;
}
}};
diff --git a/flang/test/Evaluate/fold-dot.f90 b/flang/test/Evaluate/fold-dot.f90
new file mode 100644
index 0000000000000..fb1a878ecd353
--- /dev/null
+++ b/flang/test/Evaluate/fold-dot.f90
@@ -0,0 +1,10 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of DOT_PRODUCT()
+module m
+  ! INTEGER: 1*1 + 2*2 + ... + 10*10
+  logical, parameter :: test_i4a = &
+    dot_product([(j, j=1, 10)], [(j, j=1, 10)]) == sum([(j*j, j=1, 10)])
+  ! Mixed REAL and INTEGER arguments
+  logical, parameter :: test_r4a = &
+    dot_product([(1.*j, j=1, 10)], [(j, j=1, 10)]) == sum([(j*j, j=1, 10)])
+  ! COMPLEX: the first argument is conjugated
+  logical, parameter :: test_z4a = &
+    dot_product([((j, j), j=1, 10)], [((j, j), j=1, 10)]) == &
+    sum([(((j, -j)*(j, j)), j=1, 10)])
+  ! LOGICAL: zero-sized vectors give .false.
+  logical, parameter :: test_l4a = &
+    .not. dot_product([logical::], [logical::])
+  ! LOGICAL: no position where both elements are .true.
+  logical, parameter :: test_l4b = &
+    .not. dot_product([(j == 2, j=1, 10)], [(j == 3, j=1, 10)])
+  ! LOGICAL: both elements are .true. at position 4
+  logical, parameter :: test_l4c = &
+    dot_product([(j == 4, j=1, 10)], [(j == 4, j=1, 10)])
+end module m
More information about the flang-commits
mailing list