[flang-commits] [flang] 3f59474 - [flang] Fix implementation of Kahan summation (#116897)

via flang-commits flang-commits at lists.llvm.org
Thu Nov 21 10:47:25 PST 2024


Author: Peter Klausler
Date: 2024-11-21T10:47:21-08:00
New Revision: 3f594741cf8e1537fb25f84ef3cf2245b08d8089

URL: https://github.com/llvm/llvm-project/commit/3f594741cf8e1537fb25f84ef3cf2245b08d8089
DIFF: https://github.com/llvm/llvm-project/commit/3f594741cf8e1537fb25f84ef3cf2245b08d8089.diff

LOG: [flang] Fix implementation of Kahan summation (#116897)

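In the runtime's implementation of floating-point SUM, Kahan's compensated
summation algorithm (used for increased precision) was implemented
incorrectly: the running correction term was added to each new data item,
when it should be subtracted from it. This patch fixes the runtime, and it
fixes the same sign error in the compile-time constant folding of SUM,
DOT_PRODUCT, MATMUL, and NORM2. With the fix, summing 100M random default
REAL values between 0.0 and 1.0 gives a result close to the expected 5.0E7.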

See https://en.wikipedia.org/wiki/Kahan_summation_algorithm.

Added: 
    

Modified: 
    flang/lib/Evaluate/fold-matmul.h
    flang/lib/Evaluate/fold-real.cpp
    flang/lib/Evaluate/fold-reduction.h
    flang/runtime/sum.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index be9c547d45286c..c3d65a90409098 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -61,7 +61,7 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
           auto product{aElt.Multiply(bElt)};
           overflow |= product.flags.test(RealFlag::Overflow);
           if constexpr (useKahanSummation) {
-            auto next{correction.Add(product.value, rounding)};
+            auto next{product.value.Subtract(correction, rounding)};
             overflow |= next.flags.test(RealFlag::Overflow);
             auto added{sum.Add(next.value, rounding)};
             overflow |= added.flags.test(RealFlag::Overflow);

diff  --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 0b79a417942a45..6fb5249c8a5e2e 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -78,7 +78,7 @@ template <int KIND> class Norm2Accumulator {
       auto scaled{item.Divide(scale).value};
       auto square{scaled.Multiply(scaled).value};
       if constexpr (useKahanSummation) {
-        auto next{square.Add(correction_, rounding_)};
+        auto next{square.Subtract(correction_, rounding_)};
         overflow_ |= next.flags.test(RealFlag::Overflow);
         auto sum{element.Add(next.value, rounding_)};
         overflow_ |= sum.flags.test(RealFlag::Overflow);

diff  --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 8ca0794ab0fc7c..b1b81d8740d3f3 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -47,7 +47,7 @@ static Expr<T> FoldDotProduct(
       const auto &rounding{context.targetCharacteristics().roundingMode()};
       for (const Element &x : cProducts.values()) {
         if constexpr (useKahanSummation) {
-          auto next{correction.Add(x, rounding)};
+          auto next{x.Subtract(correction, rounding)};
           overflow |= next.flags.test(RealFlag::Overflow);
           auto added{sum.Add(next.value, rounding)};
           overflow |= added.flags.test(RealFlag::Overflow);
@@ -90,7 +90,7 @@ static Expr<T> FoldDotProduct(
       const auto &rounding{context.targetCharacteristics().roundingMode()};
       for (const Element &x : cProducts.values()) {
         if constexpr (useKahanSummation) {
-          auto next{correction.Add(x, rounding)};
+          auto next{x.Subtract(correction, rounding)};
           overflow |= next.flags.test(RealFlag::Overflow);
           auto added{sum.Add(next.value, rounding)};
           overflow |= added.flags.test(RealFlag::Overflow);
@@ -348,7 +348,7 @@ template <typename T> class SumAccumulator {
       overflow_ |= sum.overflow;
       element = sum.value;
     } else { // Real & Complex: use Kahan summation
-      auto next{array_.At(at).Add(correction_, rounding_)};
+      auto next{array_.At(at).Subtract(correction_, rounding_)};
       overflow_ |= next.flags.test(RealFlag::Overflow);
       auto sum{element.Add(next.value, rounding_)};
       overflow_ |= sum.flags.test(RealFlag::Overflow);

diff  --git a/flang/runtime/sum.cpp b/flang/runtime/sum.cpp
index 04241443275eb9..10b81242546521 100644
--- a/flang/runtime/sum.cpp
+++ b/flang/runtime/sum.cpp
@@ -53,7 +53,7 @@ template <typename INTERMEDIATE> class RealSumAccumulator {
   }
   template <typename A> RT_API_ATTRS bool Accumulate(A x) {
     // Kahan summation
-    auto next{x + correction_};
+    auto next{x - correction_};
     auto oldSum{sum_};
     sum_ += next;
     correction_ = (sum_ - oldSum) - next; // algebraically zero


        


More information about the flang-commits mailing list