[flang-commits] [flang] 3502d34 - [flang] Adjust transformational folding to match runtime (#90132)

via flang-commits flang-commits at lists.llvm.org
Wed May 1 14:06:37 PDT 2024


Author: Peter Klausler
Date: 2024-05-01T14:06:32-07:00
New Revision: 3502d340c9276f1828da9db72f83e5e25b163b8b

URL: https://github.com/llvm/llvm-project/commit/3502d340c9276f1828da9db72f83e5e25b163b8b
DIFF: https://github.com/llvm/llvm-project/commit/3502d340c9276f1828da9db72f83e5e25b163b8b.diff

LOG: [flang] Adjust transformational folding to match runtime (#90132)

The transformational intrinsic functions MATMUL, DOT_PRODUCT, and NORM2
all involve summing up intermediate products into accumulators. In the
constant folding library, this is done with extended-precision Kahan
summation for REAL and COMPLEX arguments, but the runtime
implementations do not use Kahan summation, and this leads to
discrepancies between folded results and dynamic results.

Disable the use of Kahan summation in folding to resolve these
discrepancies, but don't discard the code, in case we want to add Kahan
summation in the runtime for some or all of these intrinsic functions.

Added: 
    

Modified: 
    flang/lib/Evaluate/fold-implementation.h
    flang/lib/Evaluate/fold-matmul.h
    flang/lib/Evaluate/fold-real.cpp
    flang/lib/Evaluate/fold-reduction.h

Removed: 
    


################################################################################
diff  --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index 093f26bea1a44f..2c0e0883207e1b 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -45,6 +45,12 @@
 
 namespace Fortran::evaluate {
 
+// Don't use Kahan extended precision summation any more when folding
+// transformational intrinsic functions other than SUM, since it is
+// not used in the runtime implementations of those functions and we
+// want results to match.
+static constexpr bool useKahanSummation{false};
+
 // Utilities
 template <typename T> class Folder {
 public:

diff  --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index 27b6db1fd8bf02..bd61969a822c3b 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -58,18 +58,25 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
         Element bElt{mb->At(bAt)};
         if constexpr (T::category == TypeCategory::Real ||
             T::category == TypeCategory::Complex) {
-          // Kahan summation
-          auto product{aElt.Multiply(bElt, rounding)};
+          auto product{aElt.Multiply(bElt)};
           overflow |= product.flags.test(RealFlag::Overflow);
-          auto next{correction.Add(product.value, rounding)};
-          overflow |= next.flags.test(RealFlag::Overflow);
-          auto added{sum.Add(next.value, rounding)};
-          overflow |= added.flags.test(RealFlag::Overflow);
-          correction = added.value.Subtract(sum, rounding)
-                           .value.Subtract(next.value, rounding)
-                           .value;
-          sum = std::move(added.value);
+          if constexpr (useKahanSummation) {
+            auto next{correction.Add(product.value, rounding)};
+            overflow |= next.flags.test(RealFlag::Overflow);
+            auto added{sum.Add(next.value, rounding)};
+            overflow |= added.flags.test(RealFlag::Overflow);
+            correction = added.value.Subtract(sum, rounding)
+                             .value.Subtract(next.value, rounding)
+                             .value;
+            sum = std::move(added.value);
+          } else {
+            auto added{sum.Add(product.value)};
+            overflow |= added.flags.test(RealFlag::Overflow);
+            sum = std::move(added.value);
+          }
         } else if constexpr (T::category == TypeCategory::Integer) {
+          // Don't use Kahan summation in numeric MATMUL folding;
+          // the runtime doesn't use it, and results should match.
           auto product{aElt.MultiplySigned(bElt)};
           overflow |= product.SignedMultiplicationOverflowed();
           auto added{sum.AddSigned(product.lower)};

diff  --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index fd37437c643aa2..4df709d3d2c215 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -54,7 +54,7 @@ template <int KIND> class Norm2Accumulator {
       : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
   void operator()(
       Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
-    // Kahan summation of scaled elements:
+    // Summation of scaled elements:
     // Naively,
     //   NORM2(A(:)) = SQRT(SUM(A(:)**2))
     // For any T > 0, we have mathematically
@@ -76,24 +76,27 @@ template <int KIND> class Norm2Accumulator {
       auto item{array_.At(at)};
       auto scaled{item.Divide(scale).value};
       auto square{scaled.Multiply(scaled).value};
-      auto next{square.Add(correction_, rounding_)};
-      overflow_ |= next.flags.test(RealFlag::Overflow);
-      auto sum{element.Add(next.value, rounding_)};
-      overflow_ |= sum.flags.test(RealFlag::Overflow);
-      correction_ = sum.value.Subtract(element, rounding_)
-                        .value.Subtract(next.value, rounding_)
-                        .value;
-      element = sum.value;
+      if constexpr (useKahanSummation) {
+        auto next{square.Add(correction_, rounding_)};
+        overflow_ |= next.flags.test(RealFlag::Overflow);
+        auto sum{element.Add(next.value, rounding_)};
+        overflow_ |= sum.flags.test(RealFlag::Overflow);
+        correction_ = sum.value.Subtract(element, rounding_)
+                          .value.Subtract(next.value, rounding_)
+                          .value;
+        element = sum.value;
+      } else {
+        auto sum{element.Add(square, rounding_)};
+        overflow_ |= sum.flags.test(RealFlag::Overflow);
+        element = sum.value;
+      }
     }
   }
   bool overflow() const { return overflow_; }
   void Done(Scalar<T> &result) {
-    // result+correction == SUM((data(:)/maxAbs)**2)
-    // result = maxAbs * SQRT(result+correction)
-    auto corrected{result.Add(correction_, rounding_)};
-    overflow_ |= corrected.flags.test(RealFlag::Overflow);
-    correction_ = Scalar<T>{};
-    auto root{corrected.value.SQRT().value};
+    // incoming result = SUM((data(:)/maxAbs)**2)
+    // outgoing result = maxAbs * SQRT(result)
+    auto root{result.SQRT().value};
     auto product{root.Multiply(maxAbs_.At(maxAbsAt_))};
     maxAbs_.IncrementSubscripts(maxAbsAt_);
     overflow_ |= product.flags.test(RealFlag::Overflow);

diff  --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index c84d35734ab5af..ae17770dc2961e 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -43,17 +43,23 @@ static Expr<T> FoldDotProduct(
       Expr<T> products{Fold(
           context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
-      Element correction{}; // Use Kahan summation for greater precision.
+      [[maybe_unused]] Element correction{};
       const auto &rounding{context.targetCharacteristics().roundingMode()};
       for (const Element &x : cProducts.values()) {
-        auto next{correction.Add(x, rounding)};
-        overflow |= next.flags.test(RealFlag::Overflow);
-        auto added{sum.Add(next.value, rounding)};
-        overflow |= added.flags.test(RealFlag::Overflow);
-        correction = added.value.Subtract(sum, rounding)
-                         .value.Subtract(next.value, rounding)
-                         .value;
-        sum = std::move(added.value);
+        if constexpr (useKahanSummation) {
+          auto next{correction.Add(x, rounding)};
+          overflow |= next.flags.test(RealFlag::Overflow);
+          auto added{sum.Add(next.value, rounding)};
+          overflow |= added.flags.test(RealFlag::Overflow);
+          correction = added.value.Subtract(sum, rounding)
+                           .value.Subtract(next.value, rounding)
+                           .value;
+          sum = std::move(added.value);
+        } else {
+          auto added{sum.Add(x, rounding)};
+          overflow |= added.flags.test(RealFlag::Overflow);
+          sum = std::move(added.value);
+        }
       }
     } else if constexpr (T::category == TypeCategory::Logical) {
       Expr<T> conjunctions{Fold(context,
@@ -80,17 +86,23 @@ static Expr<T> FoldDotProduct(
       Expr<T> products{
           Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
-      Element correction{}; // Use Kahan summation for greater precision.
+      [[maybe_unused]] Element correction{};
       const auto &rounding{context.targetCharacteristics().roundingMode()};
       for (const Element &x : cProducts.values()) {
-        auto next{correction.Add(x, rounding)};
-        overflow |= next.flags.test(RealFlag::Overflow);
-        auto added{sum.Add(next.value, rounding)};
-        overflow |= added.flags.test(RealFlag::Overflow);
-        correction = added.value.Subtract(sum, rounding)
-                         .value.Subtract(next.value, rounding)
-                         .value;
-        sum = std::move(added.value);
+        if constexpr (useKahanSummation) {
+          auto next{correction.Add(x, rounding)};
+          overflow |= next.flags.test(RealFlag::Overflow);
+          auto added{sum.Add(next.value, rounding)};
+          overflow |= added.flags.test(RealFlag::Overflow);
+          correction = added.value.Subtract(sum, rounding)
+                           .value.Subtract(next.value, rounding)
+                           .value;
+          sum = std::move(added.value);
+        } else {
+          auto added{sum.Add(x, rounding)};
+          overflow |= added.flags.test(RealFlag::Overflow);
+          sum = std::move(added.value);
+        }
       }
     }
     if (overflow) {


        


More information about the flang-commits mailing list