[flang-commits] [flang] [flang] Fix implementation of Kahan summation (PR #116897)
via flang-commits
flang-commits at lists.llvm.org
Tue Nov 19 16:53:16 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
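In the runtime's floating-point SUM, the implementation of Kahan's compensated-summation algorithm for increased precision is incorrect. The same error also appears in the compile-time folding of SUM, DOT_PRODUCT, MATMUL, and NORM2. The running correction term is computed as (newSum - oldSum) - next, so it records how much the previous addition overshot. It must therefore be subtracted from each new data item, not added to it. With this fix, the sum of 100M random default real values between 0. and 1. is close to 5.E7, as expected.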
See https://en.wikipedia.org/wiki/Kahan_summation_algorithm.
---
Full diff: https://github.com/llvm/llvm-project/pull/116897.diff
4 Files Affected:
- (modified) flang/lib/Evaluate/fold-matmul.h (+1-1)
- (modified) flang/lib/Evaluate/fold-real.cpp (+1-1)
- (modified) flang/lib/Evaluate/fold-reduction.h (+3-3)
- (modified) flang/runtime/sum.cpp (+1-1)
``````````diff
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index be9c547d45286c..c3d65a90409098 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -61,7 +61,7 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
if constexpr (useKahanSummation) {
- auto next{correction.Add(product.value, rounding)};
+ auto next{product.value.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 0b79a417942a45..6fb5249c8a5e2e 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -78,7 +78,7 @@ template <int KIND> class Norm2Accumulator {
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
if constexpr (useKahanSummation) {
- auto next{square.Add(correction_, rounding_)};
+ auto next{square.Subtract(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 8ca0794ab0fc7c..b1b81d8740d3f3 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -47,7 +47,7 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{correction.Add(x, rounding)};
+ auto next{x.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
@@ -90,7 +90,7 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{correction.Add(x, rounding)};
+ auto next{x.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
@@ -348,7 +348,7 @@ template <typename T> class SumAccumulator {
overflow_ |= sum.overflow;
element = sum.value;
} else { // Real & Complex: use Kahan summation
- auto next{array_.At(at).Add(correction_, rounding_)};
+ auto next{array_.At(at).Subtract(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
diff --git a/flang/runtime/sum.cpp b/flang/runtime/sum.cpp
index 04241443275eb9..10b81242546521 100644
--- a/flang/runtime/sum.cpp
+++ b/flang/runtime/sum.cpp
@@ -53,7 +53,7 @@ template <typename INTERMEDIATE> class RealSumAccumulator {
}
template <typename A> RT_API_ATTRS bool Accumulate(A x) {
// Kahan summation
- auto next{x + correction_};
+ auto next{x - correction_};
auto oldSum{sum_};
sum_ += next;
correction_ = (sum_ - oldSum) - next; // algebraically zero
``````````
</details>
https://github.com/llvm/llvm-project/pull/116897
More information about the flang-commits
mailing list