[flang-commits] [flang] [llvm] [flang] Fix spurious NaN result from infinite Kahan summation (PR #175373)
via flang-commits
flang-commits at lists.llvm.org
Sat Jan 10 09:55:27 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
There are six instances of Kahan's compensated (extended-precision) summation algorithm in flang/flang-rt, and they share a bug: once the accumulated sum saturates to Inf, computing the correction term evaluates Inf-Inf and produces a NaN. This leads to the surprising NaN result from SUM([Inf, 0.]).
This bug doesn't affect run-time calculation of SUM when optimization is enabled -- lowering emits an open-coded SUM that lacks Kahan summation -- but it does affect compile-time folding and -O0 runtime results.
Fix the one instance of Kahan summation in the runtime, and consolidate the other five instances in Evaluate into one new member function, also corrected.
Fixes https://github.com/llvm/llvm-project/issues/89528.
---
Full diff: https://github.com/llvm/llvm-project/pull/175373.diff
9 Files Affected:
- (modified) flang-rt/lib/runtime/sum.cpp (+9-3)
- (modified) flang/include/flang/Evaluate/complex.h (+3)
- (modified) flang/include/flang/Evaluate/real.h (+2)
- (modified) flang/lib/Evaluate/complex.cpp (+11)
- (modified) flang/lib/Evaluate/fold-matmul.h (+3-8)
- (modified) flang/lib/Evaluate/fold-real.cpp (+1-6)
- (modified) flang/lib/Evaluate/fold-reduction.h (+7-23)
- (modified) flang/lib/Evaluate/real.cpp (+18)
- (added) flang/test/Evaluate/bug89528.f90 (+4)
``````````diff
diff --git a/flang-rt/lib/runtime/sum.cpp b/flang-rt/lib/runtime/sum.cpp
index a76e228f18a4e..0c540606f27c3 100644
--- a/flang-rt/lib/runtime/sum.cpp
+++ b/flang-rt/lib/runtime/sum.cpp
@@ -54,9 +54,15 @@ template <typename INTERMEDIATE> class RealSumAccumulator {
template <typename A> RT_API_ATTRS bool Accumulate(A x) {
// Kahan summation
auto next{x - correction_};
- auto oldSum{sum_};
- sum_ += next;
- correction_ = (sum_ - oldSum) - next; // algebraically zero
+ auto oldSum{sum_};
+ sum_ += next;
+ correction_ = (sum_ - oldSum) - next; // algebraically zero
+ if (correction_ - correction_ != 0) {
+ // correction_ is Inf or NaN because sum_ has saturated (Inf-Inf, or
+ // overflow from finite addends); keeping it would turn an infinite
+ // sum_ into NaN on a later call, e.g. SUM([HUGE(x), HUGE(x), 0.])
+ correction_ = 0;
+ }
return true;
}
template <typename A>
diff --git a/flang/include/flang/Evaluate/complex.h b/flang/include/flang/Evaluate/complex.h
index 720ccaf512df6..9781db9a25a64 100644
--- a/flang/include/flang/Evaluate/complex.h
+++ b/flang/include/flang/Evaluate/complex.h
@@ -77,6 +77,9 @@ template <typename REAL_TYPE> class Complex {
Rounding rounding = TargetCharacteristics::defaultRounding) const;
ValueWithRealFlags<Complex> Divide(const Complex &,
Rounding rounding = TargetCharacteristics::defaultRounding) const;
+ ValueWithRealFlags<Complex> KahanSummation(const Complex &,
+ Complex &correction,
+ Rounding rounding = TargetCharacteristics::defaultRounding) const;
// ABS/CABS = HYPOT(re_, imag_) = SQRT(re_**2 + im_**2)
ValueWithRealFlags<Part> ABS(
diff --git a/flang/include/flang/Evaluate/real.h b/flang/include/flang/Evaluate/real.h
index dcd74073a4737..c0a966820d13e 100644
--- a/flang/include/flang/Evaluate/real.h
+++ b/flang/include/flang/Evaluate/real.h
@@ -175,6 +175,8 @@ template <typename WORD, int PREC> class Real {
Rounding rounding = TargetCharacteristics::defaultRounding) const;
ValueWithRealFlags<Real> MODULO(const Real &,
Rounding rounding = TargetCharacteristics::defaultRounding) const;
+ ValueWithRealFlags<Real> KahanSummation(const Real &, Real &correction,
+ Rounding rounding = TargetCharacteristics::defaultRounding) const;
template <typename INT> constexpr INT EXPONENT() const {
if (Exponent() == maxExponent) {
diff --git a/flang/lib/Evaluate/complex.cpp b/flang/lib/Evaluate/complex.cpp
index ab83f193e3f3e..a245fb38c82b9 100644
--- a/flang/lib/Evaluate/complex.cpp
+++ b/flang/lib/Evaluate/complex.cpp
@@ -100,6 +100,17 @@ ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
return {Complex{re, im}, flags};
}
+template <typename R>
+ValueWithRealFlags<Complex<R>> Complex<R>::KahanSummation(
+ const Complex &that, Complex &correction, Rounding rounding) const {
+ RealFlags flags;
+ Part reSum{re_.KahanSummation(that.re_, correction.re_, rounding)
+ .AccumulateFlags(flags)};
+ Part imSum{im_.KahanSummation(that.im_, correction.im_, rounding)
+ .AccumulateFlags(flags)};
+ return {Complex{reSum, imSum}, flags};
+}
+
template <typename R> std::string Complex<R>::DumpHexadecimal() const {
std::string result{'('};
result += re_.DumpHexadecimal();
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index ae9221f9ce042..a8a24c09774e8 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -61,18 +61,13 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
if constexpr (useKahanSummation) {
- auto next{product.value.Subtract(correction, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
+ auto added{sum.KahanSummation(product.value, correction, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ sum = added.value;
} else {
auto added{sum.Add(product.value)};
overflow |= added.flags.test(RealFlag::Overflow);
- sum = std::move(added.value);
+ sum = added.value;
}
} else if constexpr (T::category == TypeCategory::Integer) {
auto product{aElt.MultiplySigned(bElt)};
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 1ff941053a82e..cf83342b5bd9d 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -77,13 +77,8 @@ template <int KIND> class Norm2Accumulator {
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
if constexpr (useKahanSummation) {
- auto next{square.Subtract(correction_, rounding_)};
- overflow_ |= next.flags.test(RealFlag::Overflow);
- auto sum{element.Add(next.value, rounding_)};
+ auto sum{element.KahanSummation(square, correction_, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
- correction_ = sum.value.Subtract(element, rounding_)
- .value.Subtract(next.value, rounding_)
- .value;
element = sum.value;
} else {
auto sum{element.Add(square, rounding_)};
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index fe897393fe13e..a068364135295 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -47,18 +47,13 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{x.Subtract(correction, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
+ auto added{sum.KahanSummation(x, correction, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ sum = added.value;
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
- sum = std::move(added.value);
+ sum = added.value;
}
}
} else if constexpr (T::category == TypeCategory::Logical) {
@@ -97,18 +92,13 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{x.Subtract(correction, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
+ auto added{sum.KahanSummation(x, correction, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ sum = added.value;
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
- sum = std::move(added.value);
+ sum = added.value;
}
}
}
@@ -357,14 +347,8 @@ template <typename T> class SumAccumulator {
} else if constexpr (T::category == TypeCategory::Unsigned) {
element = element.AddUnsigned(array_.At(at)).value;
} else { // Real & Complex: use Kahan summation
- auto next{array_.At(at).Subtract(correction_, rounding_)};
- overflow_ |= next.flags.test(RealFlag::Overflow);
- auto sum{element.Add(next.value, rounding_)};
+ auto sum{element.KahanSummation(array_.At(at), correction_, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
- // correction = (sum - element) - next; algebraically zero
- correction_ = sum.value.Subtract(element, rounding_)
- .value.Subtract(next.value, rounding_)
- .value;
element = sum.value;
}
}
diff --git a/flang/lib/Evaluate/real.cpp b/flang/lib/Evaluate/real.cpp
index 6e6b9f3ac77c2..eb335ce328517 100644
--- a/flang/lib/Evaluate/real.cpp
+++ b/flang/lib/Evaluate/real.cpp
@@ -465,6 +465,24 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::MODULO(
return result;
}
+template <typename W, int P>
+ValueWithRealFlags<Real<W, P>> Real<W, P>::KahanSummation(
+ const Real &y, Real &correction, Rounding rounding) const {
+ Real next{y.Subtract(correction, rounding).value};
+ auto sum{Add(next, rounding)};
+ // correction = (sum - *this) - next; algebraically zero
+ correction = sum.value.Subtract(*this, rounding)
+ .value.Subtract(next, rounding)
+ .value;
+ if (correction.IsNotANumber() || correction.IsInfinite()) {
+ // The sum has saturated to +/-Inf (Inf-Inf gives NaN; an overflow of
+ // finite addends gives Inf). Either would turn an infinite sum into
+ // NaN on the next call, so restart the compensation from zero.
+ correction = Real{}; // 0.
+ }
+ return sum;
+}
+
template <typename W, int P>
ValueWithRealFlags<Real<W, P>> Real<W, P>::DIM(
const Real &y, Rounding rounding) const {
diff --git a/flang/test/Evaluate/bug89528.f90 b/flang/test/Evaluate/bug89528.f90
new file mode 100644
index 0000000000000..991e9c04812c3
--- /dev/null
+++ b/flang/test/Evaluate/bug89528.f90
@@ -0,0 +1,4 @@
+!RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
+!CHECK: REAL :: avoidkahannan = (1._4/0.)
+real :: avoidKahanNaN = sum([1./0., 0.]) ! Inf, not NaN
+end
``````````
</details>
https://github.com/llvm/llvm-project/pull/175373
More information about the flang-commits
mailing list