[flang-commits] [flang] [flang] Avoid needless overflow when folding NORM2 (PR #67499)
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Fri Sep 29 14:19:13 PDT 2023
https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/67499
>From e1b83ed7afe3ee733a7c10b9347c6c94450b337b Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Tue, 26 Sep 2023 15:34:03 -0700
Subject: [PATCH] [flang] Avoid needless overflow when folding NORM2
The code that folds the relatively new NORM2 intrinsic
function can produce overflow in cases where it's not
warranted. Rearrange to NORM2 = M * SQRT(SUM((A(:)/M)**2))
where M is MAXVAL(ABS(A)).
---
flang/lib/Evaluate/fold-real.cpp | 28 ++++++++++++++++++++++------
flang/lib/Evaluate/fold-reduction.h | 2 +-
flang/test/Evaluate/fold-norm2.f90 | 13 ++++++++++---
3 files changed, 33 insertions(+), 10 deletions(-)
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 8e3ab1d8fd30b09..6bcc3ec73982157 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -52,15 +52,28 @@ template <int KIND> class Norm2Accumulator {
const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding)
: array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
- // Kahan summation of scaled elements
+ // Kahan summation of scaled elements:
+ // Naively,
+ // NORM2(A(:)) = SQRT(SUM(A(:)**2))
+ // For any T > 0, we have mathematically
+ // SQRT(SUM(A(:)**2))
+ // = SQRT(T**2 * (SUM(A(:)**2) / T**2))
+ // = SQRT(T**2 * SUM(A(:)**2 / T**2))
+ // = SQRT(T**2 * SUM((A(:)/T)**2))
+ // = SQRT(T**2) * SQRT(SUM((A(:)/T)**2))
+ // = T * SQRT(SUM((A(:)/T)**2))
+ // By letting T = MAXVAL(ABS(A)), we ensure that
+ // ALL(ABS(A(:)/T) <= 1), so ALL((A(:)/T)**2 <= 1), and the SUM will
+ // not overflow unless absolutely necessary.
auto scale{maxAbs_.At(maxAbsAt_)};
if (scale.IsZero()) {
- // If maxAbs is zero, so are all elements, and result
+ // Maximum value is zero, and so will the result be.
+ // Avoid division by zero below.
element = scale;
} else {
auto item{array_.At(at)};
auto scaled{item.Divide(scale).value};
- auto square{item.Multiply(scaled).value};
+ auto square{scaled.Multiply(scaled).value};
auto next{square.Add(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
@@ -73,13 +86,16 @@ template <int KIND> class Norm2Accumulator {
}
bool overflow() const { return overflow_; }
void Done(Scalar<T> &result) {
+ // result+correction == SUM((data(:)/maxAbs)**2)
+ // result = maxAbs * SQRT(result+correction)
auto corrected{result.Add(correction_, rounding_)};
overflow_ |= corrected.flags.test(RealFlag::Overflow);
correction_ = Scalar<T>{};
- auto rescaled{corrected.value.Multiply(maxAbs_.At(maxAbsAt_))};
+ auto root{corrected.value.SQRT().value};
+ auto product{root.Multiply(maxAbs_.At(maxAbsAt_))};
maxAbs_.IncrementSubscripts(maxAbsAt_);
- overflow_ |= rescaled.flags.test(RealFlag::Overflow);
- result = rescaled.value.SQRT().value;
+ overflow_ |= product.flags.test(RealFlag::Overflow);
+ result = product.value;
}
private:
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index cff7f54c60d91ba..0dd55124e6a512e 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -228,7 +228,7 @@ template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
test.Rewrite(context_, std::move(test)))};
CHECK(folded.has_value());
if (folded->IsTrue()) {
- element = array_.At(at);
+ element = aAt;
}
}
void Done(Scalar<T> &) const {}
diff --git a/flang/test/Evaluate/fold-norm2.f90 b/flang/test/Evaluate/fold-norm2.f90
index 30d5289b5a6e33c..370532bafaa13cf 100644
--- a/flang/test/Evaluate/fold-norm2.f90
+++ b/flang/test/Evaluate/fold-norm2.f90
@@ -17,13 +17,20 @@ module m
real(dp), parameter :: a(3,4) = &
reshape([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape(a))
real(dp), parameter :: nAll = norm2(a)
- real(dp), parameter :: check_nAll = sqrt(sum(a * a))
+ real(dp), parameter :: check_nAll = 11._dp * sqrt(sum((a/11._dp)**2))
logical, parameter :: test_all = nAll == check_nAll
real(dp), parameter :: norms1(4) = norm2(a, dim=1)
- real(dp), parameter :: check_norms1(4) = sqrt(sum(a * a, dim=1))
+ real(dp), parameter :: check_norms1(4) = [ &
+ 2.236067977499789805051477742381393909454345703125_8, &
+ 7.07106781186547550532850436866283416748046875_8, &
+ 1.2206555615733702069292121450416743755340576171875e1_8, &
+ 1.7378147196982769884243680280633270740509033203125e1_8 ]
logical, parameter :: test_norms1 = all(norms1 == check_norms1)
real(dp), parameter :: norms2(3) = norm2(a, dim=2)
- real(dp), parameter :: check_norms2(3) = sqrt(sum(a * a, dim=2))
+ real(dp), parameter :: check_norms2(3) = [ &
+ 1.1224972160321822656214862945489585399627685546875e1_8, &
+ 1.28840987267251261272349438513629138469696044921875e1_8, &
+ 1.4628738838327791427218471653759479522705078125e1_8 ]
logical, parameter :: test_norms2 = all(norms2 == check_norms2)
logical, parameter :: test_normZ = norm2([0.,0.,0.]) == 0.
end
More information about the flang-commits
mailing list