[flang-commits] [flang] [flang] Adjust transformational folding to match runtime (PR #90132)
via flang-commits
flang-commits at lists.llvm.org
Thu Apr 25 15:17:50 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
The transformational intrinsic functions MATMUL, DOT_PRODUCT, and NORM2 all sum intermediate products into accumulators. In the constant folding library, this summation is done with extended-precision Kahan (compensated) summation for REAL and COMPLEX arguments, but the runtime implementations do not use it, and this leads to discrepancies between compile-time folded results and results computed at run time.
Disable the use of Kahan summation in folding to resolve these discrepancies, but keep the code in place behind the compile-time flag `useKahanSummation` (now `false`) rather than discarding it, in case we want to add Kahan summation in the runtime for some or all of these intrinsic functions.
---
Full diff: https://github.com/llvm/llvm-project/pull/90132.diff
4 Files Affected:
- (modified) flang/lib/Evaluate/fold-implementation.h (+6)
- (modified) flang/lib/Evaluate/fold-matmul.h (+17-10)
- (modified) flang/lib/Evaluate/fold-real.cpp (+18-15)
- (modified) flang/lib/Evaluate/fold-reduction.h (+30-18)
``````````diff
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index 093f26bea1a44f..2c0e0883207e1b 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -45,6 +45,12 @@
namespace Fortran::evaluate {
+// Don't use Kahan extended precision summation any more when folding
+// transformational intrinsic functions other than SUM, since it is
+// not used in the runtime implementations of those functions and we
+// want results to match.
+static constexpr bool useKahanSummation{false};
+
// Utilities
template <typename T> class Folder {
public:
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index 27b6db1fd8bf02..bd61969a822c3b 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -58,18 +58,25 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
Element bElt{mb->At(bAt)};
if constexpr (T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex) {
- // Kahan summation
- auto product{aElt.Multiply(bElt, rounding)};
+ auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
- auto next{correction.Add(product.value, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
- overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ if constexpr (useKahanSummation) {
+ auto next{correction.Add(product.value, rounding)};
+ overflow |= next.flags.test(RealFlag::Overflow);
+ auto added{sum.Add(next.value, rounding)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ correction = added.value.Subtract(sum, rounding)
+ .value.Subtract(next.value, rounding)
+ .value;
+ sum = std::move(added.value);
+ } else {
+ auto added{sum.Add(product.value)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ sum = std::move(added.value);
+ }
} else if constexpr (T::category == TypeCategory::Integer) {
+ // Don't use Kahan summation in numeric MATMUL folding;
+ // the runtime doesn't use it, and results should match.
auto product{aElt.MultiplySigned(bElt)};
overflow |= product.SignedMultiplicationOverflowed();
auto added{sum.AddSigned(product.lower)};
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index fd37437c643aa2..4df709d3d2c215 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -54,7 +54,7 @@ template <int KIND> class Norm2Accumulator {
: array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
void operator()(
Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
- // Kahan summation of scaled elements:
+ // Summation of scaled elements:
// Naively,
// NORM2(A(:)) = SQRT(SUM(A(:)**2))
// For any T > 0, we have mathematically
@@ -76,24 +76,27 @@ template <int KIND> class Norm2Accumulator {
auto item{array_.At(at)};
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
- auto next{square.Add(correction_, rounding_)};
- overflow_ |= next.flags.test(RealFlag::Overflow);
- auto sum{element.Add(next.value, rounding_)};
- overflow_ |= sum.flags.test(RealFlag::Overflow);
- correction_ = sum.value.Subtract(element, rounding_)
- .value.Subtract(next.value, rounding_)
- .value;
- element = sum.value;
+ if constexpr (useKahanSummation) {
+ auto next{square.Add(correction_, rounding_)};
+ overflow_ |= next.flags.test(RealFlag::Overflow);
+ auto sum{element.Add(next.value, rounding_)};
+ overflow_ |= sum.flags.test(RealFlag::Overflow);
+ correction_ = sum.value.Subtract(element, rounding_)
+ .value.Subtract(next.value, rounding_)
+ .value;
+ element = sum.value;
+ } else {
+ auto sum{element.Add(square, rounding_)};
+ overflow_ |= sum.flags.test(RealFlag::Overflow);
+ element = sum.value;
+ }
}
}
bool overflow() const { return overflow_; }
void Done(Scalar<T> &result) {
- // result+correction == SUM((data(:)/maxAbs)**2)
- // result = maxAbs * SQRT(result+correction)
- auto corrected{result.Add(correction_, rounding_)};
- overflow_ |= corrected.flags.test(RealFlag::Overflow);
- correction_ = Scalar<T>{};
- auto root{corrected.value.SQRT().value};
+ // incoming result = SUM((data(:)/maxAbs)**2)
+ // outgoing result = maxAbs * SQRT(result)
+ auto root{result.SQRT().value};
auto product{root.Multiply(maxAbs_.At(maxAbsAt_))};
maxAbs_.IncrementSubscripts(maxAbsAt_);
overflow_ |= product.flags.test(RealFlag::Overflow);
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index c84d35734ab5af..ae17770dc2961e 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -43,17 +43,23 @@ static Expr<T> FoldDotProduct(
Expr<T> products{Fold(
context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
- Element correction{}; // Use Kahan summation for greater precision.
+ [[maybe_unused]] Element correction{};
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
- auto next{correction.Add(x, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
- overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ if constexpr (useKahanSummation) {
+ auto next{correction.Add(x, rounding)};
+ overflow |= next.flags.test(RealFlag::Overflow);
+ auto added{sum.Add(next.value, rounding)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ correction = added.value.Subtract(sum, rounding)
+ .value.Subtract(next.value, rounding)
+ .value;
+ sum = std::move(added.value);
+ } else {
+ auto added{sum.Add(x, rounding)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ sum = std::move(added.value);
+ }
}
} else if constexpr (T::category == TypeCategory::Logical) {
Expr<T> conjunctions{Fold(context,
@@ -80,17 +86,23 @@ static Expr<T> FoldDotProduct(
Expr<T> products{
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
- Element correction{}; // Use Kahan summation for greater precision.
+ [[maybe_unused]] Element correction{};
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
- auto next{correction.Add(x, rounding)};
- overflow |= next.flags.test(RealFlag::Overflow);
- auto added{sum.Add(next.value, rounding)};
- overflow |= added.flags.test(RealFlag::Overflow);
- correction = added.value.Subtract(sum, rounding)
- .value.Subtract(next.value, rounding)
- .value;
- sum = std::move(added.value);
+ if constexpr (useKahanSummation) {
+ auto next{correction.Add(x, rounding)};
+ overflow |= next.flags.test(RealFlag::Overflow);
+ auto added{sum.Add(next.value, rounding)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ correction = added.value.Subtract(sum, rounding)
+ .value.Subtract(next.value, rounding)
+ .value;
+ sum = std::move(added.value);
+ } else {
+ auto added{sum.Add(x, rounding)};
+ overflow |= added.flags.test(RealFlag::Overflow);
+ sum = std::move(added.value);
+ }
}
}
if (overflow) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/90132
More information about the flang-commits
mailing list