[flang-commits] [flang] [flang] Fold NORM2() (PR #66240)
via flang-commits
flang-commits at lists.llvm.org
Wed Sep 13 09:55:16 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
<details>
<summary>Changes</summary>
Fold references to the (relatively new) intrinsic function NORM2 at compilation time when the argument(s) are all constants. (Getting this done right involved some changes to the API of the accumulator function objects used by the DoReduction<> template, which rippled through some other reduction function folding code.)
--
Patch is 20.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66240.diff
5 Files Affected:
- (modified) flang/lib/Evaluate/fold-integer.cpp (+23-13)
- (modified) flang/lib/Evaluate/fold-logical.cpp (+1-4)
- (modified) flang/lib/Evaluate/fold-real.cpp (+77-1)
- (modified) flang/lib/Evaluate/fold-reduction.h (+118-50)
- (added) flang/test/Evaluate/fold-norm2.f90 (+29)
<pre>
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index fe38c81d976822d..dedfc20a491cd88 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -264,6 +264,26 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
}
// COUNT()
+// Accumulator for COUNT(): adds one to the result element for each true
+// element of MASK and remembers whether any increment overflowed.
+template <typename T, int MASK_KIND> class CountAccumulator {
+  using MaskT = Type<TypeCategory::Logical, MASK_KIND>;
+
+public:
+  // explicit: an accumulator must never arise from an implicit conversion
+  // of a Constant<MaskT>.
+  explicit CountAccumulator(const Constant<MaskT> &mask) : mask_{mask} {}
+  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+    if (mask_.At(at).IsTrue()) {
+      auto incremented{element.AddSigned(Scalar<T>{1})};
+      overflow_ |= incremented.overflow;
+      // The value returned by AddSigned is stored even on overflow; the
+      // caller reports the overflow via overflow().
+      element = incremented.value;
+    }
+  }
+  bool overflow() const { return overflow_; }
+  // COUNT needs no finalization of a result element.
+  void Done(Scalar<T> &) const {}
+
+private:
+  // Held by reference, not copied: the mask must outlive this accumulator.
+  const Constant<MaskT> &mask_;
+  bool overflow_{false};
+};
+
template <typename T, int maskKind>
static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
using LogicalResult = Type<TypeCategory::Logical, maskKind>;
@@ -274,17 +294,9 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
: Folder<LogicalResult>{context}.Folding(arg[0])}) {
std::optional<int> dim;
if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
- bool overflow{false};
- auto accumulator{
- [&mask, &overflow](Scalar<T> &element, const ConstantSubscripts &at) {
- if (mask->At(at).IsTrue()) {
- auto incremented{element.AddSigned(Scalar<T>{1})};
- overflow |= incremented.overflow;
- element = incremented.value;
- }
- }};
+ CountAccumulator<T, maskKind> accumulator{*mask};
Constant<T> result{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
- if (overflow) {
+ if (accumulator.overflow()) {
context.messages().Say(
"Result of intrinsic function COUNT overflows its result type"_warn_en_US);
}
@@ -513,9 +525,7 @@ static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
- auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
- element = (element.*operation)(array->At(at));
- }};
+ OperationAccumulator<T> accumulator{*array, operation};
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index 95335f7f48bbedf..9fc42adf805f468 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -28,14 +28,11 @@ static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
Scalar<T> identity) {
static_assert(T::category == TypeCategory::Logical);
- using Element = Scalar<T>;
std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
- auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
- element = (element.*operation)(array->At(at));
- }};
+ OperationAccumulator accumulator{*array, operation};
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 671d897ef7b2f82..8e3ab1d8fd30b09 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -43,6 +43,80 @@ static Expr<T> FoldTransformationalBessel(
return Expr<T>{std::move(funcRef)};
}
+// NORM2
+// Accumulator for NORM2(): sums the squares of the elements after scaling
+// them by S = MAXVAL(ABS(X)) of the same reduction, so that every term is
+// at most 1 and no intermediate value overflows unless the result does:
+//   NORM2(X) = S * SQRT(SUM((X/S)**2))   for S > 0
+template <int KIND> class Norm2Accumulator {
+  using T = Type<TypeCategory::Real, KIND>;
+
+public:
+  // maxAbs is expected to hold one scale factor per result element, in the
+  // same order in which DoReduction finishes result elements via Done().
+  Norm2Accumulator(
+      const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding)
+      : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {}
+  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+    auto scale{CurrentScale()};
+    if (scale.IsZero()) {
+      // The largest magnitude is zero, so every element and the result are
+      // zero; this also avoids a division by zero below.
+      element = scale;
+    } else {
+      // Kahan summation of (X/S)**2; each term lies in [0, 1].
+      auto scaled{array_.At(at).Divide(scale, rounding_).value};
+      auto square{scaled.Multiply(scaled, rounding_).value};
+      auto next{square.Add(correction_, rounding_)};
+      overflow_ |= next.flags.test(RealFlag::Overflow);
+      auto sum{element.Add(next.value, rounding_)};
+      overflow_ |= sum.flags.test(RealFlag::Overflow);
+      // correction = (sum - element) - next; algebraically zero
+      correction_ = sum.value.Subtract(element, rounding_)
+                        .value.Subtract(next.value, rounding_)
+                        .value;
+      element = sum.value;
+    }
+  }
+  bool overflow() const { return overflow_; }
+  // Finishes one result element: folds in the residual Kahan correction,
+  // takes the square root, and only then multiplies by the scale, so the
+  // unscaled sum of squares (which may overflow) is never formed.
+  void Done(Scalar<T> &result) {
+    auto corrected{result.Add(correction_, rounding_)};
+    overflow_ |= corrected.flags.test(RealFlag::Overflow);
+    correction_ = Scalar<T>{};
+    auto root{corrected.value.SQRT(rounding_).value};
+    auto rescaled{root.Multiply(CurrentScale(), rounding_)};
+    maxAbs_.IncrementSubscripts(maxAbsAt_);
+    overflow_ |= rescaled.flags.test(RealFlag::Overflow);
+    result = rescaled.value;
+  }
+
+private:
+  // ABS() keeps the scale non-negative whatever sign maxAbs happens to hold.
+  Scalar<T> CurrentScale() const { return maxAbs_.At(maxAbsAt_).ABS(); }
+
+  const Constant<T> &array_;
+  const Constant<T> &maxAbs_;
+  const Rounding rounding_;
+  bool overflow_{false};
+  Scalar<T> correction_{};
+  ConstantSubscripts maxAbsAt_{maxAbs_.lbounds()};
+};
+
+// Folds NORM2(X [, DIM]) once its arguments are constant. A first reduction
+// computes MAXVAL(ABS(X)) per result element for use as a scale factor; a
+// second reduction with Norm2Accumulator produces the folded result.
+template <int KIND>
+static Expr<Type<TypeCategory::Real, KIND>> FoldNorm2(FoldingContext &context,
+    FunctionRef<Type<TypeCategory::Real, KIND>> &&funcRef) {
+  using T = Type<TypeCategory::Real, KIND>;
+  const typename Constant<T>::Element zero{};
+  std::optional<int> dim;
+  std::optional<Constant<T>> x{ProcessReductionArgs<T>(
+      context, funcRef.arguments(), dim, zero, /*X=*/0, /*DIM=*/1)};
+  if (!x) {
+    // ProcessReductionArgs produced no constant X: leave the call unfolded.
+    return Expr<T>{std::move(funcRef)};
+  }
+  MaxvalMinvalAccumulator<T, /*ABS=*/true> absMax{
+      RelationalOperator::GT, context, *x};
+  Constant<T> scales{DoReduction<T>(*x, dim, zero, absMax)};
+  Norm2Accumulator norm2{
+      *x, scales, context.targetCharacteristics().roundingMode()};
+  Constant<T> folded{DoReduction<T>(*x, dim, zero, norm2)};
+  if (norm2.overflow()) {
+    context.messages().Say(
+        "NORM2() of REAL(%d) data overflowed"_warn_en_US, KIND);
+  }
+  return Expr<T>{std::move(folded)};
+}
+
template <int KIND>
Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
FoldingContext &context,
@@ -238,6 +312,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
},
sExpr->u);
}
+ } else if (name == "norm2") {
+ return FoldNorm2<T::kind>(context, std::move(funcRef));
} else if (name == "product") {
auto one{Scalar<T>::FromInteger(value::Integer<8>{1}).value};
return FoldProduct<T>(context, std::move(funcRef), one);
@@ -354,7 +430,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return result.value;
}));
}
- // TODO: dot_product, matmul, norm2
+ // TODO: matmul
return Expr<T>{std::move(funcRef)};
}
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index b76cecffaf1c639..cff7f54c60d91ba 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -6,8 +6,6 @@
//
//===----------------------------------------------------------------------===//
-// TODO: NORM2, PARITY
-
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
@@ -77,7 +75,8 @@ static Expr<T> FoldDotProduct(
overflow |= next.overflow;
sum = std::move(next.value);
}
- } else { // T::category == TypeCategory::Real
+ } else {
+ static_assert(T::category == TypeCategory::Real);
Expr<T> products{
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
@@ -172,7 +171,8 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
}
// Generalized reduction to an array of one dimension fewer (w/ DIM=)
-// or to a scalar (w/o DIM=).
+// or to a scalar (w/o DIM=). The ACCUMULATOR type must define
+// operator()(Scalar<T> &, const ConstantSubscripts &) and Done(Scalar<T> &).
template <typename T, typename ACCUMULATOR, typename ARRAY>
static Constant<T> DoReduction(const Constant<ARRAY> &array,
std::optional<int> &dim, const Scalar<T> &identity,
@@ -193,6 +193,7 @@ static Constant<T> DoReduction(const Constant<ARRAY> &array,
for (ConstantSubscript j{0}; j < dimExtent; ++j, ++dimAt) {
accumulator(elements.back(), at);
}
+ accumulator.Done(elements.back());
}
} else { // no DIM=, result is scalar
elements.push_back(identity);
@@ -200,6 +201,7 @@ static Constant<T> DoReduction(const Constant<ARRAY> &array,
IncrementSubscripts(at, array.shape())) {
accumulator(elements.back(), at);
}
+ accumulator.Done(elements.back());
}
if constexpr (T::category == TypeCategory::Character) {
return {static_cast<ConstantSubscript>(identity.size()),
@@ -210,58 +212,85 @@ static Constant<T> DoReduction(const Constant<ARRAY> &array,
}
// MAXVAL & MINVAL
+// Accumulator for MAXVAL & MINVAL, and (with ABS=true) MAXVAL(ABS(...)) as
+// used by NORM2: the result element is replaced by the (optionally
+// absolute) array element whenever "opr(array element, result element)"
+// folds to true.
+template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
+public:
+  MaxvalMinvalAccumulator(
+      RelationalOperator opr, FoldingContext &context, const Constant<T> &array)
+      : opr_{opr}, context_{context}, array_{array} {}
+  void operator()(Scalar<T> &element, const ConstantSubscripts &at) const {
+    auto aAt{array_.At(at)};
+    if constexpr (ABS) {
+      aAt = aAt.ABS();
+    }
+    Expr<LogicalResult> test{PackageRelation(
+        opr_, Expr<T>{Constant<T>{aAt}}, Expr<T>{Constant<T>{element}})};
+    auto folded{GetScalarConstantValue<LogicalResult>(
+        test.Rewrite(context_, std::move(test)))};
+    CHECK(folded.has_value());
+    if (folded->IsTrue()) {
+      // Store exactly the value that was compared. With ABS, storing the
+      // raw (possibly negative) element would make later comparisons use
+      // the wrong magnitude: starting from 0, [-5., 0.] would reduce to 0.
+      element = aAt;
+    }
+  }
+  // MAXVAL/MINVAL need no finalization of a result element.
+  void Done(Scalar<T> &) const {}
+
+private:
+  RelationalOperator opr_;
+  FoldingContext &context_;
+  const Constant<T> &array_;
+};
+
+// Folds MAXVAL/MINVAL(ARRAY [, DIM] [, MASK]) when ProcessReductionArgs
+// yields a constant array; otherwise the reference is returned unchanged.
+// "identity" seeds the reduction and "opr" is the comparison deciding when
+// an array element replaces the current result (presumably GT for MAXVAL
+// and LT for MINVAL -- the callers are not visible here).
template <typename T>
static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
RelationalOperator opr, const Scalar<T> &identity) {
static_assert(T::category == TypeCategory::Integer ||
T::category == TypeCategory::Real ||
T::category == TypeCategory::Character);
- using Element = Scalar<T>;
std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
- auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
- Expr<LogicalResult> test{PackageRelation(opr,
- Expr<T>{Constant<T>{array->At(at)}}, Expr<T>{Constant<T>{element}})};
- auto folded{GetScalarConstantValue<LogicalResult>(
- test.Rewrite(context, std::move(test)))};
- CHECK(folded.has_value());
- if (folded->IsTrue()) {
- element = array->At(at);
- }
- }};
+    // MaxvalMinvalAccumulator performs the compare-and-replace step that
+    // the removed lambda used to perform inline.
+ MaxvalMinvalAccumulator accumulator{opr, context, *array};
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};
}
// PRODUCT
+// Accumulator for PRODUCT(): multiplies each array element into the result
+// element and records whether any multiplication overflowed.
+template <typename T> class ProductAccumulator {
+public:
+  // explicit: an accumulator must never arise from an implicit conversion
+  // of a Constant<T>.
+  explicit ProductAccumulator(const Constant<T> &array) : array_{array} {}
+  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+    if constexpr (T::category == TypeCategory::Integer) {
+      auto prod{element.MultiplySigned(array_.At(at))};
+      overflow_ |= prod.SignedMultiplicationOverflowed();
+      // Keep the low-order half of the double-width product.
+      element = prod.lower;
+    } else { // Real & Complex
+      auto prod{element.Multiply(array_.At(at))};
+      overflow_ |= prod.flags.test(RealFlag::Overflow);
+      element = prod.value;
+    }
+  }
+  bool overflow() const { return overflow_; }
+  // PRODUCT needs no finalization of a result element.
+  void Done(Scalar<T> &) const {}
+
+private:
+  const Constant<T> &array_;
+  bool overflow_{false};
+};
+
template <typename T>
static Expr<T> FoldProduct(
FoldingContext &context, FunctionRef<T> &&ref, Scalar<T> identity) {
static_assert(T::category == TypeCategory::Integer ||
T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex);
- using Element = typename Constant<T>::Element;
std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
- bool overflow{false};
- auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
- if constexpr (T::category == TypeCategory::Integer) {
- auto prod{element.MultiplySigned(array->At(at))};
- overflow |= prod.SignedMultiplicationOverflowed();
- element = prod.lower;
- } else { // Real & Complex
- auto prod{element.Multiply(array->At(at))};
- overflow |= prod.flags.test(RealFlag::Overflow);
- element = prod.value;
- }
- }};
+ ProductAccumulator accumulator{*array};
auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
- if (overflow) {
+ if (accumulator.overflow()) {
context.messages().Say(
"PRODUCT() of %s data overflowed"_warn_en_US, T::AsFortran());
}
@@ -271,6 +300,46 @@ static Expr<T> FoldProduct(
}
// SUM
+// Accumulator for SUM(). Integer data is added with signed-overflow
+// detection; Real and Complex data use Kahan (compensated) summation, and
+// Done() folds the pending compensation into each finished result element.
+template <typename T> class SumAccumulator {
+  using Element = typename Constant<T>::Element;
+
+public:
+  SumAccumulator(const Constant<T> &array, Rounding rounding)
+      : data_{array}, mode_{rounding} {}
+  void operator()(Element &element, const ConstantSubscripts &at) {
+    if constexpr (T::category != TypeCategory::Integer) {
+      // Kahan step: add the compensated term, then measure what was lost.
+      auto compensated{data_.At(at).Add(compensation_, mode_)};
+      auto total{element.Add(compensated.value, mode_)};
+      overflowed_ = overflowed_ ||
+          compensated.flags.test(RealFlag::Overflow) ||
+          total.flags.test(RealFlag::Overflow);
+      // compensation = (total - element) - compensated; algebraically zero
+      auto gained{total.value.Subtract(element, mode_).value};
+      compensation_ = gained.Subtract(compensated.value, mode_).value;
+      element = total.value;
+    } else {
+      auto total{element.AddSigned(data_.At(at))};
+      overflowed_ = overflowed_ || total.overflow;
+      element = total.value;
+    }
+  }
+  bool overflow() const { return overflowed_; }
+  void Done([[maybe_unused]] Element &element) {
+    if constexpr (T::category == TypeCategory::Integer) {
+      return; // integer addition carries no compensation term
+    } else {
+      auto adjusted{element.Add(compensation_, mode_)};
+      overflowed_ = overflowed_ || adjusted.flags.test(RealFlag::Overflow);
+      compensation_ = Element{};
+      element = adjusted.value;
+    }
+  }
+
+private:
+  const Constant<T> &data_;
+  Rounding mode_;
+  bool overflowed_{false};
+  Element compensation_{}; // pending Kahan compensation (unused for Integer)
+};
+
template <typename T>
static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
static_assert(T::category == TypeCategory::Integer ||
@@ -278,31 +347,14 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
T::category == TypeCategory::Complex);
using Element = typename Constant<T>::Element;
std::optional<int> dim;
- Element identity{}, correction{};
+ Element identity{};
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
- bool overflow{false};
- auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
- if constexpr (T::category == TypeCategory::Integer) {
- auto sum{element.AddSigned(array->At(at))};
- overflow |= sum.overflow;
- element = sum.value;
- } else { // Real & Complex: use Kahan summation
- const auto &rounding{context.targetCharacteristics().roundingMode()};
- auto next{array->At(at).Add(correction, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto sum{element.Add(next.value, rounding)};
- overflow |= sum.flags.test(RealFlag::Overflow);
- // correction = (sum - element) - next; algebraically zero
- correction = sum.value.Subtract(element, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- element = sum.value;
- }
- }};
+ SumAccumulator accumulator{
+ *array, context.targetCharacteristics().roundingMode()};
auto result{Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)}};
- if (overflow) {
+ if (accumulator.overflow()) {
context.messages().Say(
"SUM() of %s data overflowed"_warn_en_US, T::AsFortran());
}
@@ -311,5 +363,21 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
return Expr<T>{std::move(ref)};
}
+// Utility for IALL, IANY, IPARITY, ALL, ANY, & PARITY
+// Generic accumulator for IALL, IANY, IPARITY, ALL, ANY, & PARITY: combines
+// each array element into the result element through a const member
+// function of Scalar<T>, i.e. element = element.op(array element).
+template <typename T> class OperationAccumulator {
+  using Operation = Scalar<T> (Scalar<T>::*)(const Scalar<T> &) const;
+
+public:
+  OperationAccumulator(const Constant<T> &array, Operation operation)
+      : source_{array}, combine_{operation} {}
+  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
+    auto combined{(element.*combine_)(source_.At(at))};
+    element = combined;
+  }
+  // These reductions need no finalization of a result element.
+  void Done(Scalar<T> &) const {}
+
+private:
+  const Constant<T> &source_;
+  Operation combine_;
+};
+
} // namespace Fortran::evaluate
#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_
diff --git a/flang/test/Evaluate/fold-norm2.f90 b/flang/test/Evaluate/fold-norm2.f90
new file mode 100644
index 000000000000000..30d5289b5a6e33c
--- /dev/null
+++ b/flang/test/Evaluate/fold-norm2.f90
@@ -0,0 +1,29 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of NORM2(), F'2023 16.9.153
+module m
+ ! Examples from the standard
+ logical, parameter :: test_ex1 = norm2([3.,4.]) == 5.
+ real, parameter :: ex2(2,2) = resha...
<truncated>
</pre>
</details>
https://github.com/llvm/llvm-project/pull/66240
More information about the flang-commits
mailing list