[flang-commits] [flang] [flang] Fix implementation of Kahan summation (PR #116897)
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Tue Nov 19 16:52:41 PST 2024
https://github.com/klausler created https://github.com/llvm/llvm-project/pull/116897
In the runtime's implementation of floating-point SUM, the implementation of Kahan's algorithm for increased precision is incorrect: the running correction term should be subtracted from each new data item, not added to it. The same mistake appears in the compile-time folding code for SUM, DOT_PRODUCT, MATMUL, and NORM2, which this patch also corrects. With this fix, the sum of 100M random default real values between 0. and 1. is close to 5.E7, as expected.
See https://en.wikipedia.org/wiki/Kahan_summation_algorithm.
From a91765d1c39f2289f913377fe4649a4353a8abd2 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Tue, 19 Nov 2024 16:46:39 -0800
Subject: [PATCH] [flang] Fix implementation of Kahan summation
In the runtime's implementation of floating-point SUM, the implementation
of Kahan's algorithm for increased precision is incorrect: the running
correction term should be subtracted from each new data item, not added
to it. The same mistake appears in the compile-time folding code for
SUM, DOT_PRODUCT, MATMUL, and NORM2, which this patch also corrects.
With this fix, the sum of 100M random default real values between 0.
and 1. is close to 5.E7, as expected.
See https://en.wikipedia.org/wiki/Kahan_summation_algorithm.
---
flang/lib/Evaluate/fold-matmul.h | 2 +-
flang/lib/Evaluate/fold-real.cpp | 2 +-
flang/lib/Evaluate/fold-reduction.h | 6 +++---
flang/runtime/sum.cpp | 2 +-
4 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index be9c547d45286c..c3d65a90409098 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -61,7 +61,7 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
if constexpr (useKahanSummation) {
- auto next{correction.Add(product.value, rounding)};
+ auto next{product.value.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 0b79a417942a45..6fb5249c8a5e2e 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -78,7 +78,7 @@ template <int KIND> class Norm2Accumulator {
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
if constexpr (useKahanSummation) {
- auto next{square.Add(correction_, rounding_)};
+ auto next{square.Subtract(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 8ca0794ab0fc7c..b1b81d8740d3f3 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -47,7 +47,7 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{correction.Add(x, rounding)};
+ auto next{x.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
@@ -90,7 +90,7 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{correction.Add(x, rounding)};
+ auto next{x.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
@@ -348,7 +348,7 @@ template <typename T> class SumAccumulator {
overflow_ |= sum.overflow;
element = sum.value;
} else { // Real & Complex: use Kahan summation
- auto next{array_.At(at).Add(correction_, rounding_)};
+ auto next{array_.At(at).Subtract(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
diff --git a/flang/runtime/sum.cpp b/flang/runtime/sum.cpp
index 04241443275eb9..10b81242546521 100644
--- a/flang/runtime/sum.cpp
+++ b/flang/runtime/sum.cpp
@@ -53,7 +53,7 @@ template <typename INTERMEDIATE> class RealSumAccumulator {
}
template <typename A> RT_API_ATTRS bool Accumulate(A x) {
// Kahan summation
- auto next{x + correction_};
+ auto next{x - correction_};
auto oldSum{sum_};
sum_ += next;
correction_ = (sum_ - oldSum) - next; // algebraically zero
More information about the flang-commits
mailing list