[flang-commits] [flang] [llvm] [flang] Fix spurious NaN result from infinite Kahan summation (PR #175373)

via flang-commits flang-commits at lists.llvm.org
Sat Jan 10 09:55:27 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Peter Klausler (klausler)

<details>
<summary>Changes</summary>

There are six instances of Kahan's extended-precision summation algorithm in flang/flang-rt, and they share a bug: the calculation of the correction value produces a NaN due to the subtraction Inf-Inf after the accumulation saturates to Inf.  This leads to the surprising NaN result from SUM([Inf, 0.]).

This bug doesn't affect the run-time calculation of SUM when optimization is enabled -- lowering emits an open-coded SUM that lacks Kahan summation -- but it does affect compile-time folding and -O0 run-time results.

Fix the one instance of Kahan summation in the runtime, and consolidate the other five instances in Evaluate into one new member function, also corrected.

Fixes https://github.com/llvm/llvm-project/issues/89528.

---
Full diff: https://github.com/llvm/llvm-project/pull/175373.diff


9 Files Affected:

- (modified) flang-rt/lib/runtime/sum.cpp (+9-3) 
- (modified) flang/include/flang/Evaluate/complex.h (+3) 
- (modified) flang/include/flang/Evaluate/real.h (+2) 
- (modified) flang/lib/Evaluate/complex.cpp (+11) 
- (modified) flang/lib/Evaluate/fold-matmul.h (+3-8) 
- (modified) flang/lib/Evaluate/fold-real.cpp (+1-6) 
- (modified) flang/lib/Evaluate/fold-reduction.h (+7-23) 
- (modified) flang/lib/Evaluate/real.cpp (+18) 
- (added) flang/test/Evaluate/bug89528.f90 (+4) 


``````````diff
diff --git a/flang-rt/lib/runtime/sum.cpp b/flang-rt/lib/runtime/sum.cpp
index a76e228f18a4e..0c540606f27c3 100644
--- a/flang-rt/lib/runtime/sum.cpp
+++ b/flang-rt/lib/runtime/sum.cpp
@@ -54,9 +54,15 @@ template <typename INTERMEDIATE> class RealSumAccumulator {
   template <typename A> RT_API_ATTRS bool Accumulate(A x) {
     // Kahan summation
     auto next{x - correction_};
-    auto oldSum{sum_};
-    sum_ += next;
-    correction_ = (sum_ - oldSum) - next; // algebraically zero
+    auto oldSum{sum_};
+    sum_ += next;
+    correction_ = (sum_ - oldSum) - next; // algebraically zero
+    if (correction_ - correction_ != 0) {
+      // sum_ saturated to +/-Inf (or is NaN), so correction_ is Inf or NaN;
+      // keeping it would turn a later term into a spurious NaN via Inf-Inf,
+      // e.g. SUM([HUGE(x), HUGE(x), 0.]) or SUM([Inf, 0.]).
+      correction_ = 0;
+    }
     return true;
   }
   template <typename A>
diff --git a/flang/include/flang/Evaluate/complex.h b/flang/include/flang/Evaluate/complex.h
index 720ccaf512df6..9781db9a25a64 100644
--- a/flang/include/flang/Evaluate/complex.h
+++ b/flang/include/flang/Evaluate/complex.h
@@ -77,6 +77,9 @@ template <typename REAL_TYPE> class Complex {
       Rounding rounding = TargetCharacteristics::defaultRounding) const;
   ValueWithRealFlags<Complex> Divide(const Complex &,
       Rounding rounding = TargetCharacteristics::defaultRounding) const;
+  ValueWithRealFlags<Complex> KahanSummation(const Complex &,
+      Complex &correction,
+      Rounding rounding = TargetCharacteristics::defaultRounding) const;
 
   // ABS/CABS = HYPOT(re_, imag_) = SQRT(re_**2 + im_**2)
   ValueWithRealFlags<Part> ABS(
diff --git a/flang/include/flang/Evaluate/real.h b/flang/include/flang/Evaluate/real.h
index dcd74073a4737..c0a966820d13e 100644
--- a/flang/include/flang/Evaluate/real.h
+++ b/flang/include/flang/Evaluate/real.h
@@ -175,6 +175,8 @@ template <typename WORD, int PREC> class Real {
       Rounding rounding = TargetCharacteristics::defaultRounding) const;
   ValueWithRealFlags<Real> MODULO(const Real &,
       Rounding rounding = TargetCharacteristics::defaultRounding) const;
+  ValueWithRealFlags<Real> KahanSummation(const Real &, Real &correction,
+      Rounding rounding = TargetCharacteristics::defaultRounding) const;
 
   template <typename INT> constexpr INT EXPONENT() const {
     if (Exponent() == maxExponent) {
diff --git a/flang/lib/Evaluate/complex.cpp b/flang/lib/Evaluate/complex.cpp
index ab83f193e3f3e..a245fb38c82b9 100644
--- a/flang/lib/Evaluate/complex.cpp
+++ b/flang/lib/Evaluate/complex.cpp
@@ -100,6 +100,17 @@ ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
   return {Complex{re, im}, flags};
 }
 
+template <typename R>
+ValueWithRealFlags<Complex<R>> Complex<R>::KahanSummation(
+    const Complex &that, Complex &correction, Rounding rounding) const {
+  // Each component is summed independently with its own Kahan correction.
+  auto re{re_.KahanSummation(that.re_, correction.re_, rounding)};
+  auto im{im_.KahanSummation(that.im_, correction.im_, rounding)};
+  RealFlags flags{re.flags};
+  flags |= im.flags;
+  return {Complex{re.value, im.value}, flags};
+}
+
 template <typename R> std::string Complex<R>::DumpHexadecimal() const {
   std::string result{'('};
   result += re_.DumpHexadecimal();
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index ae9221f9ce042..a8a24c09774e8 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -61,18 +61,13 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
           auto product{aElt.Multiply(bElt)};
           overflow |= product.flags.test(RealFlag::Overflow);
           if constexpr (useKahanSummation) {
-            auto next{product.value.Subtract(correction, rounding)};
-            overflow |= next.flags.test(RealFlag::Overflow);
-            auto added{sum.Add(next.value, rounding)};
+            auto added{sum.KahanSummation(product.value, correction, rounding)};
             overflow |= added.flags.test(RealFlag::Overflow);
-            correction = added.value.Subtract(sum, rounding)
-                             .value.Subtract(next.value, rounding)
-                             .value;
-            sum = std::move(added.value);
+            sum = added.value;
           } else {
             auto added{sum.Add(product.value)};
             overflow |= added.flags.test(RealFlag::Overflow);
-            sum = std::move(added.value);
+            sum = added.value;
           }
         } else if constexpr (T::category == TypeCategory::Integer) {
           auto product{aElt.MultiplySigned(bElt)};
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 1ff941053a82e..cf83342b5bd9d 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -77,13 +77,8 @@ template <int KIND> class Norm2Accumulator {
       auto scaled{item.Divide(scale).value};
       auto square{scaled.Multiply(scaled).value};
       if constexpr (useKahanSummation) {
-        auto next{square.Subtract(correction_, rounding_)};
-        overflow_ |= next.flags.test(RealFlag::Overflow);
-        auto sum{element.Add(next.value, rounding_)};
+        auto sum{element.KahanSummation(square, correction_, rounding_)};
         overflow_ |= sum.flags.test(RealFlag::Overflow);
-        correction_ = sum.value.Subtract(element, rounding_)
-                          .value.Subtract(next.value, rounding_)
-                          .value;
         element = sum.value;
       } else {
         auto sum{element.Add(square, rounding_)};
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index fe897393fe13e..a068364135295 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -47,18 +47,13 @@ static Expr<T> FoldDotProduct(
       const auto &rounding{context.targetCharacteristics().roundingMode()};
       for (const Element &x : cProducts.values()) {
         if constexpr (useKahanSummation) {
-          auto next{x.Subtract(correction, rounding)};
-          overflow |= next.flags.test(RealFlag::Overflow);
-          auto added{sum.Add(next.value, rounding)};
+          auto added{sum.KahanSummation(x, correction, rounding)};
           overflow |= added.flags.test(RealFlag::Overflow);
-          correction = added.value.Subtract(sum, rounding)
-                           .value.Subtract(next.value, rounding)
-                           .value;
-          sum = std::move(added.value);
+          sum = added.value;
         } else {
           auto added{sum.Add(x, rounding)};
           overflow |= added.flags.test(RealFlag::Overflow);
-          sum = std::move(added.value);
+          sum = added.value;
         }
       }
     } else if constexpr (T::category == TypeCategory::Logical) {
@@ -97,18 +92,13 @@ static Expr<T> FoldDotProduct(
       const auto &rounding{context.targetCharacteristics().roundingMode()};
       for (const Element &x : cProducts.values()) {
         if constexpr (useKahanSummation) {
-          auto next{x.Subtract(correction, rounding)};
-          overflow |= next.flags.test(RealFlag::Overflow);
-          auto added{sum.Add(next.value, rounding)};
+          auto added{sum.KahanSummation(x, correction, rounding)};
           overflow |= added.flags.test(RealFlag::Overflow);
-          correction = added.value.Subtract(sum, rounding)
-                           .value.Subtract(next.value, rounding)
-                           .value;
-          sum = std::move(added.value);
+          sum = added.value;
         } else {
           auto added{sum.Add(x, rounding)};
           overflow |= added.flags.test(RealFlag::Overflow);
-          sum = std::move(added.value);
+          sum = added.value;
         }
       }
     }
@@ -357,14 +347,8 @@ template <typename T> class SumAccumulator {
     } else if constexpr (T::category == TypeCategory::Unsigned) {
       element = element.AddUnsigned(array_.At(at)).value;
     } else { // Real & Complex: use Kahan summation
-      auto next{array_.At(at).Subtract(correction_, rounding_)};
-      overflow_ |= next.flags.test(RealFlag::Overflow);
-      auto sum{element.Add(next.value, rounding_)};
+      auto sum{element.KahanSummation(array_.At(at), correction_, rounding_)};
       overflow_ |= sum.flags.test(RealFlag::Overflow);
-      // correction = (sum - element) - next; algebraically zero
-      correction_ = sum.value.Subtract(element, rounding_)
-                        .value.Subtract(next.value, rounding_)
-                        .value;
       element = sum.value;
     }
   }
diff --git a/flang/lib/Evaluate/real.cpp b/flang/lib/Evaluate/real.cpp
index 6e6b9f3ac77c2..eb335ce328517 100644
--- a/flang/lib/Evaluate/real.cpp
+++ b/flang/lib/Evaluate/real.cpp
@@ -465,6 +465,24 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::MODULO(
   return result;
 }
 
+template <typename W, int P>
+ValueWithRealFlags<Real<W, P>> Real<W, P>::KahanSummation(
+    const Real &y, Real &correction, Rounding rounding) const {
+  // *this + y with Kahan compensation; 'correction' holds the rounding error.
+  auto next{y.Subtract(correction, rounding)};
+  auto sum{Add(next.value, rounding)};
+  sum.flags |= next.flags; // keep e.g. an Overflow from the subtraction
+  correction = sum.value.Subtract(*this, rounding) // (sum - *this) - next,
+                   .value.Subtract(next.value, rounding) // algebraically 0
+                   .value;
+  if (correction.IsInfinite() || correction.IsNotANumber()) {
+    // The sum saturated to Inf (or is NaN); keeping this correction would
+    // let a later Inf-Inf produce a spurious NaN.  Reset it to zero.
+    correction = Real{}; // 0.
+  }
+  return sum;
+}
+
 template <typename W, int P>
 ValueWithRealFlags<Real<W, P>> Real<W, P>::DIM(
     const Real &y, Rounding rounding) const {
diff --git a/flang/test/Evaluate/bug89528.f90 b/flang/test/Evaluate/bug89528.f90
new file mode 100644
index 0000000000000..991e9c04812c3
--- /dev/null
+++ b/flang/test/Evaluate/bug89528.f90
@@ -0,0 +1,4 @@
+!RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
+!CHECK: REAL :: avoidkahannan = (1._4/0.)
+real :: avoidKahanNaN = sum([1./0., 0.]) ! must fold to +Inf, not NaN
+end

``````````

</details>


https://github.com/llvm/llvm-project/pull/175373


More information about the flang-commits mailing list