[flang-commits] [flang] 3f59474 - [flang] Fix implementation of Kahan summation (#116897)
via flang-commits
flang-commits at lists.llvm.org
Thu Nov 21 10:47:25 PST 2024
Author: Peter Klausler
Date: 2024-11-21T10:47:21-08:00
New Revision: 3f594741cf8e1537fb25f84ef3cf2245b08d8089
URL: https://github.com/llvm/llvm-project/commit/3f594741cf8e1537fb25f84ef3cf2245b08d8089
DIFF: https://github.com/llvm/llvm-project/commit/3f594741cf8e1537fb25f84ef3cf2245b08d8089.diff
LOG: [flang] Fix implementation of Kahan summation (#116897)
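In the runtime's implementation of floating-point SUM, Kahan's
compensated summation algorithm (used for increased precision) is
implemented incorrectly, and the same error is present in the
compile-time constant folding of SUM, DOT_PRODUCT, MATMUL, and NORM2
(lib/Evaluate). The running correction factor should be subtracted from
each new data item, not added to it. With this fix, the sum of 100
million random default real values between 0. and 1. is close to the
expected value of 5.E7.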
See https://en.wikipedia.org/wiki/Kahan_summation_algorithm.
Added:
Modified:
flang/lib/Evaluate/fold-matmul.h
flang/lib/Evaluate/fold-real.cpp
flang/lib/Evaluate/fold-reduction.h
flang/runtime/sum.cpp
Removed:
################################################################################
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index be9c547d45286c..c3d65a90409098 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -61,7 +61,7 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
if constexpr (useKahanSummation) {
- auto next{correction.Add(product.value, rounding)};
+ auto next{product.value.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 0b79a417942a45..6fb5249c8a5e2e 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -78,7 +78,7 @@ template <int KIND> class Norm2Accumulator {
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
if constexpr (useKahanSummation) {
- auto next{square.Add(correction_, rounding_)};
+ auto next{square.Subtract(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 8ca0794ab0fc7c..b1b81d8740d3f3 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -47,7 +47,7 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{correction.Add(x, rounding)};
+ auto next{x.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
@@ -90,7 +90,7 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{correction.Add(x, rounding)};
+ auto next{x.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
@@ -348,7 +348,7 @@ template <typename T> class SumAccumulator {
overflow_ |= sum.overflow;
element = sum.value;
} else { // Real & Complex: use Kahan summation
- auto next{array_.At(at).Add(correction_, rounding_)};
+ auto next{array_.At(at).Subtract(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
diff --git a/flang/runtime/sum.cpp b/flang/runtime/sum.cpp
index 04241443275eb9..10b81242546521 100644
--- a/flang/runtime/sum.cpp
+++ b/flang/runtime/sum.cpp
@@ -53,7 +53,7 @@ template <typename INTERMEDIATE> class RealSumAccumulator {
}
template <typename A> RT_API_ATTRS bool Accumulate(A x) {
// Kahan summation
- auto next{x + correction_};
+ auto next{x - correction_};
auto oldSum{sum_};
sum_ += next;
correction_ = (sum_ - oldSum) - next; // algebraically zero
More information about the flang-commits
mailing list