[flang-commits] [flang] [flang] Avoid needless overflow when folding NORM2 (PR #67499)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Fri Sep 29 09:56:15 PDT 2023


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/67499

>From 2a9d514c06f2381f6b2d38ce59531d208f66427b Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Tue, 26 Sep 2023 15:34:03 -0700
Subject: [PATCH] [flang] Avoid needless overflow when folding NORM2

The code that folds the relatively new NORM2 intrinsic
function can produce overflow in cases where it's not
warranted.  Rearrange to NORM2 = M * SQRT(SUM((A(:)/M)**2))
where M is MAXVAL(ABS(A)).
---
 flang/lib/Evaluate/fold-real.cpp    | 28 ++++++++++++++++++++++------
 flang/lib/Evaluate/fold-reduction.h |  2 +-
 flang/test/Evaluate/fold-norm2.f90  | 13 ++++++++++---
 3 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 8e3ab1d8fd30b09..6bcc3ec73982157 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -52,15 +52,28 @@ template <int KIND> class Norm2Accumulator {
       const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding)
       : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
   void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
-    // Kahan summation of scaled elements
+    // Kahan summation of scaled elements:
+    // Naively,
+    //   NORM2(A(:)) = SQRT(SUM(A(:)**2))
+    // For any T > 0, we have mathematically
+    //   SQRT(SUM(A(:)**2))
+    //     = SQRT(T**2 * (SUM(A(:)**2) / T**2))
+    //     = SQRT(T**2 * SUM(A(:)**2 / T**2))
+    //     = SQRT(T**2 * SUM((A(:)/T)**2))
+    //     = SQRT(T**2) * SQRT(SUM((A(:)/T)**2))
+    //     = T * SQRT(SUM((A(:)/T)**2))
+    // By letting T = MAXVAL(ABS(A)), we ensure that
+    // ALL(ABS(A(:)/T) <= 1), so ALL((A(:)/T)**2 <= 1), and the SUM will
+    // not overflow unless absolutely necessary.
     auto scale{maxAbs_.At(maxAbsAt_)};
     if (scale.IsZero()) {
-      // If maxAbs is zero, so are all elements, and result
+      // Maximum value is zero, and so will the result be.
+      // Avoid division by zero below.
       element = scale;
     } else {
       auto item{array_.At(at)};
       auto scaled{item.Divide(scale).value};
-      auto square{item.Multiply(scaled).value};
+      auto square{scaled.Multiply(scaled).value};
       auto next{square.Add(correction_, rounding_)};
       overflow_ |= next.flags.test(RealFlag::Overflow);
       auto sum{element.Add(next.value, rounding_)};
@@ -73,13 +86,16 @@ template <int KIND> class Norm2Accumulator {
   }
   bool overflow() const { return overflow_; }
   void Done(Scalar<T> &result) {
+    // result+correction == SUM((data(:)/maxAbs)**2)
+    // result = maxAbs * SQRT(result+correction)
     auto corrected{result.Add(correction_, rounding_)};
     overflow_ |= corrected.flags.test(RealFlag::Overflow);
     correction_ = Scalar<T>{};
-    auto rescaled{corrected.value.Multiply(maxAbs_.At(maxAbsAt_))};
+    auto root{corrected.value.SQRT().value};
+    auto product{root.Multiply(maxAbs_.At(maxAbsAt_))};
     maxAbs_.IncrementSubscripts(maxAbsAt_);
-    overflow_ |= rescaled.flags.test(RealFlag::Overflow);
-    result = rescaled.value.SQRT().value;
+    overflow_ |= product.flags.test(RealFlag::Overflow);
+    result = product.value;
   }
 
 private:
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index cff7f54c60d91ba..0dd55124e6a512e 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -228,7 +228,7 @@ template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
         test.Rewrite(context_, std::move(test)))};
     CHECK(folded.has_value());
     if (folded->IsTrue()) {
-      element = array_.At(at);
+      element = aAt;
     }
   }
   void Done(Scalar<T> &) const {}
diff --git a/flang/test/Evaluate/fold-norm2.f90 b/flang/test/Evaluate/fold-norm2.f90
index 30d5289b5a6e33c..370532bafaa13cf 100644
--- a/flang/test/Evaluate/fold-norm2.f90
+++ b/flang/test/Evaluate/fold-norm2.f90
@@ -17,13 +17,20 @@ module m
   real(dp), parameter :: a(3,4) = &
     reshape([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape(a))
   real(dp), parameter :: nAll = norm2(a)
-  real(dp), parameter :: check_nAll = sqrt(sum(a * a))
+  real(dp), parameter :: check_nAll = 11._dp * sqrt(sum((a/11._dp)**2))
   logical, parameter :: test_all = nAll == check_nAll
   real(dp), parameter :: norms1(4) = norm2(a, dim=1)
-  real(dp), parameter :: check_norms1(4) = sqrt(sum(a * a, dim=1))
+  real(dp), parameter :: check_norms1(4) = [ &
+    2.236067977499789805051477742381393909454345703125_8, &
+    7.07106781186547550532850436866283416748046875_8, &
+    1.2206555615733702069292121450416743755340576171875e1_8, &
+    1.7378147196982769884243680280633270740509033203125e1_8 ]
   logical, parameter :: test_norms1 = all(norms1 == check_norms1)
   real(dp), parameter :: norms2(3) = norm2(a, dim=2)
-  real(dp), parameter :: check_norms2(3) = sqrt(sum(a * a, dim=2))
+  real(dp), parameter :: check_norms2(3) = [ &
+    1.1224972160321822656214862945489585399627685546875e1_8, &
+    1.28840987267251261272349438513629138469696044921875e1_8, &
+    1.4628738838327791427218471653759479522705078125e1_8 ]
   logical, parameter :: test_norms2 = all(norms2 == check_norms2)
   logical, parameter :: test_normZ = norm2([0.,0.,0.]) == 0.
 end



More information about the flang-commits mailing list