[flang-commits] [flang] c9e9635 - [flang] evaluate: Fold SQRT, HYPOT, & CABS

peter klausler via flang-commits flang-commits at lists.llvm.org
Tue Sep 7 10:33:18 PDT 2021


Author: peter klausler
Date: 2021-09-07T10:33:11-07:00
New Revision: c9e9635ffef781c32a839a77d122d7930edfc9b2

URL: https://github.com/llvm/llvm-project/commit/c9e9635ffef781c32a839a77d122d7930edfc9b2
DIFF: https://github.com/llvm/llvm-project/commit/c9e9635ffef781c32a839a77d122d7930edfc9b2.diff

LOG: [flang] evaluate: Fold SQRT, HYPOT, & CABS

Implement IEEE Real::SQRT() operation, then use it to
also implement Real::HYPOT(), which can then be used directly
to implement Complex::ABS().

Differential Revision: https://reviews.llvm.org/D109250

Added: 
    flang/test/Evaluate/folding28.f90

Modified: 
    flang/include/flang/Evaluate/complex.h
    flang/include/flang/Evaluate/real.h
    flang/lib/Evaluate/fold-real.cpp
    flang/lib/Evaluate/intrinsics-library.cpp
    flang/lib/Evaluate/real.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Evaluate/complex.h b/flang/include/flang/Evaluate/complex.h
index 5b8861caf4bf8..2feb25bc353eb 100644
--- a/flang/include/flang/Evaluate/complex.h
+++ b/flang/include/flang/Evaluate/complex.h
@@ -77,6 +77,11 @@ template <typename REAL_TYPE> class Complex {
   ValueWithRealFlags<Complex> Divide(
       const Complex &, Rounding rounding = defaultRounding) const;
 
+  // ABS/CABS = HYPOT(re_, im_) = SQRT(re_**2 + im_**2), overflow-safe
+  ValueWithRealFlags<Part> ABS(Rounding rounding = defaultRounding) const {
+    return re_.HYPOT(im_, rounding);
+  }
+
   constexpr Complex FlushSubnormalToZero() const {
     return {re_.FlushSubnormalToZero(), im_.FlushSubnormalToZero()};
   }
@@ -88,7 +93,6 @@ template <typename REAL_TYPE> class Complex {
   std::string DumpHexadecimal() const;
   llvm::raw_ostream &AsFortran(llvm::raw_ostream &, int kind) const;
 
-  // TODO: (C)ABS once Real::HYPOT is done
   // TODO: unit testing
 
 private:

diff  --git a/flang/include/flang/Evaluate/real.h b/flang/include/flang/Evaluate/real.h
index 9cd6f8305bd5c..fa637abe9c2e7 100644
--- a/flang/include/flang/Evaluate/real.h
+++ b/flang/include/flang/Evaluate/real.h
@@ -115,8 +115,10 @@ class Real : public common::RealDetails<PREC> {
   ValueWithRealFlags<Real> Divide(
       const Real &, Rounding rounding = defaultRounding) const;
 
-  // SQRT(x**2 + y**2) but computed so as to avoid spurious overflow
-  // TODO: not yet implemented; needed for CABS
+  ValueWithRealFlags<Real> SQRT(Rounding rounding = defaultRounding) const;
+
+  // HYPOT(x,y)=SQRT(x**2 + y**2) computed so as to avoid spurious
+  // intermediate overflows.
   ValueWithRealFlags<Real> HYPOT(
       const Real &, Rounding rounding = defaultRounding) const;
 

diff  --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 0ee465536a2a6..bffd8ea757555 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -27,8 +27,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
       name == "bessel_y1" || name == "cos" || name == "cosh" || name == "erf" ||
       name == "erfc" || name == "erfc_scaled" || name == "exp" ||
       name == "gamma" || name == "log" || name == "log10" ||
-      name == "log_gamma" || name == "sin" || name == "sinh" ||
-      name == "sqrt" || name == "tan" || name == "tanh") {
+      name == "log_gamma" || name == "sin" || name == "sinh" || name == "tan" ||
+      name == "tanh") {
     CHECK(args.size() == 1);
     if (auto callable{GetHostRuntimeWrapper<T, T>(name)}) {
       return FoldElementalIntrinsic<T, T>(
@@ -40,8 +40,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
   } else if (name == "amax0" || name == "amin0" || name == "amin1" ||
       name == "amax1" || name == "dmin1" || name == "dmax1") {
     return RewriteSpecificMINorMAX(context, std::move(funcRef));
-  } else if (name == "atan" || name == "atan2" || name == "hypot" ||
-      name == "mod") {
+  } else if (name == "atan" || name == "atan2" || name == "mod") {
     std::string localName{name == "atan" ? "atan2" : name};
     CHECK(args.size() == 2);
     if (auto callable{GetHostRuntimeWrapper<T, T, T>(localName)}) {
@@ -71,13 +70,10 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
       return FoldElementalIntrinsic<T, T>(
           context, std::move(funcRef), &Scalar<T>::ABS);
     } else if (auto *z{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
-      if (auto callable{GetHostRuntimeWrapper<T, ComplexT>("abs")}) {
-        return FoldElementalIntrinsic<T, ComplexT>(
-            context, std::move(funcRef), *callable);
-      } else {
-        context.messages().Say(
-            "abs(complex(kind=%d)) cannot be folded on host"_en_US, KIND);
-      }
+      return FoldElementalIntrinsic<T, ComplexT>(context, std::move(funcRef),
+          ScalarFunc<T, ComplexT>([](const Scalar<ComplexT> &z) -> Scalar<T> {
+            return z.ABS().value;
+          }));
     } else {
       common::die(" unexpected argument type inside abs");
     }
@@ -108,6 +104,13 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
     return Expr<T>{Scalar<T>::EPSILON()};
   } else if (name == "huge") {
     return Expr<T>{Scalar<T>::HUGE()};
+  } else if (name == "hypot") {
+    CHECK(args.size() == 2);
+    return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
+        ScalarFunc<T, T, T>(
+            [](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> {
+              return x.HYPOT(y).value;
+            }));
   } else if (name == "max") {
     return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
   } else if (name == "maxval") {
@@ -130,6 +133,10 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
   } else if (name == "sign") {
     return FoldElementalIntrinsic<T, T, T>(
         context, std::move(funcRef), &Scalar<T>::SIGN);
+  } else if (name == "sqrt") {
+    return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
+        ScalarFunc<T, T>(
+            [](const Scalar<T> &x) -> Scalar<T> { return x.SQRT().value; }));
   } else if (name == "sum") {
     return FoldSum<T>(context, std::move(funcRef));
   } else if (name == "tiny") {

diff  --git a/flang/lib/Evaluate/intrinsics-library.cpp b/flang/lib/Evaluate/intrinsics-library.cpp
index 2aef9f701f0f7..e1e1c97c3f024 100644
--- a/flang/lib/Evaluate/intrinsics-library.cpp
+++ b/flang/lib/Evaluate/intrinsics-library.cpp
@@ -222,7 +222,6 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
       FolderFactory<F, F{std::erfc}>::Create("erfc"),
       FolderFactory<F, F{std::exp}>::Create("exp"),
       FolderFactory<F, F{std::tgamma}>::Create("gamma"),
-      FolderFactory<F2, F2{std::hypot}>::Create("hypot"),
       FolderFactory<F, F{std::log}>::Create("log"),
       FolderFactory<F, F{std::log10}>::Create("log10"),
       FolderFactory<F, F{std::lgamma}>::Create("log_gamma"),
@@ -230,7 +229,6 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
       FolderFactory<F2, F2{std::pow}>::Create("pow"),
       FolderFactory<F, F{std::sin}>::Create("sin"),
       FolderFactory<F, F{std::sinh}>::Create("sinh"),
-      FolderFactory<F, F{std::sqrt}>::Create("sqrt"),
       FolderFactory<F, F{std::tan}>::Create("tan"),
       FolderFactory<F, F{std::tanh}>::Create("tanh"),
   };

diff  --git a/flang/lib/Evaluate/real.cpp b/flang/lib/Evaluate/real.cpp
index 2146789049bea..e01e00b5ce504 100644
--- a/flang/lib/Evaluate/real.cpp
+++ b/flang/lib/Evaluate/real.cpp
@@ -261,6 +261,122 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::Divide(
   return result;
 }
 
+template <typename W, int P>
+ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
+  ValueWithRealFlags<Real> result;
+  if (IsNotANumber()) {
+    result.value = NotANumber();
+    if (IsSignalingNaN()) {
+      result.flags.set(RealFlag::InvalidArgument);
+    }
+  } else if (IsZero()) {
+    // SQRT(+0) == +0 and SQRT(-0) == -0 in IEEE-754.
+    result.value = *this;
+  } else if (IsNegative()) {
+    // The square root of any other negative value is an invalid operation.
+    result.flags.set(RealFlag::InvalidArgument);
+    result.value = NotANumber();
+  } else if (IsInfinite()) {
+    // SQRT(+Inf) == +Inf
+    result.value = Infinity(false);
+  } else {
+    // Reduce the argument to x = 4**k * m with 1 <= m < 4, so that
+    // SQRT(x) = 2**k * SQRT(m) with 1 <= SQRT(m) < 2.  Both scalings are
+    // exact, and the squares of candidate roots stay within [1, 4], clear of
+    // overflow and underflow even when x is huge or subnormal.
+    int unbiased{UnbiasedExponent()};
+    int expo{unbiased}; // exponent of the leading 1 bit of the significand
+    if (IsSubnormal()) {
+      expo -= GetFraction().LEADZ();
+    }
+    // k = FLOOR(expo/2).  C++ '/' truncates toward zero, which would pick
+    // too large a power of four for negative odd exponents (e.g. x = 0.5).
+    int k{expo >= 0 ? expo / 2 : -((1 - expo) / 2)};
+    Real reduced; // m
+    reduced.Normalize(false, unbiased - 2 * k + exponentBias, GetFraction());
+    // Slow but reliable bit-at-a-time method.  Start with 1.0 and then try
+    // to set significand bits in descending order of magnitude without
+    // letting the square of the root exceed m.
+    Real root;
+    root.Normalize(false, exponentBias, Fraction::MASKL(1)); // 1.0
+    for (int bit{significandBits - 1}; bit >= 0; --bit) {
+      Word word{root.word_};
+      root.word_ = word.IBSET(bit);
+      if (reduced.Compare(root.Multiply(root, rounding).value) ==
+          Relation::Less) {
+        root.word_ = word;
+      }
+    }
+    // The root's square does not exceed m.  Check it against the square of
+    // the next larger value, and take that one if its square is closer to m.
+    // NOTE: the squares are themselves rounded, so this choice may still be
+    // off by one ulp when m lies very near the midpoint of the two squares.
+    auto rootSq{root.Multiply(root)};
+    Real rootDiff{reduced.Subtract(rootSq.value).value.ABS()};
+    bool exact{rootDiff.IsZero() && !rootSq.flags.test(RealFlag::Inexact)};
+    if (!exact) {
+      Real ulp;
+      ulp.Normalize(false, exponentBias, Fraction::MASKR(1)); // ulp of 1.0
+      Real next{root.Add(ulp).value};
+      auto nextSq{next.Multiply(next)};
+      Real nextDiff{reduced.Subtract(nextSq.value).value.ABS()};
+      if (nextDiff.Compare(rootDiff) == Relation::Less) {
+        root = next;
+        exact = nextDiff.IsZero() && !nextSq.flags.test(RealFlag::Inexact);
+      }
+    }
+    // Undo the range reduction by scaling the root by 2**k (exactly).
+    result.value.Normalize(
+        false, root.UnbiasedExponent() + k + exponentBias, root.GetFraction());
+    if (!exact) {
+      result.flags.set(RealFlag::Inexact);
+    }
+  }
+  return result;
+}
+
+// HYPOT(x,y) = SQRT(x**2 + y**2) by definition, but those squared intermediate
+// values are susceptible to over/underflow when computed naively.
+// Assuming that ABS(x) >= ABS(y), calculate instead:
+//   HYPOT(x,y) = SQRT(x**2 * (1+(y/x)**2))
+//              = ABS(x) * SQRT(1+(y/x)**2)
+template <typename W, int P>
+ValueWithRealFlags<Real<W, P>> Real<W, P>::HYPOT(
+    const Real &y, Rounding rounding) const {
+  ValueWithRealFlags<Real> result;
+  bool xIsNaN{IsNotANumber()}, yIsNaN{y.IsNotANumber()};
+  if ((xIsNaN && IsSignalingNaN()) || (yIsNaN && y.IsSignalingNaN())) {
+    result.flags.set(RealFlag::InvalidArgument);
+    result.value = NotANumber();
+  } else if (IsInfinite() || y.IsInfinite()) {
+    // HYPOT(+/-Inf, y) == +Inf even if y is a quiet NaN (IEEE-754 hypot);
+    // handling this here also avoids computing Inf/Inf below.
+    result.value = Infinity(false);
+  } else if (xIsNaN || yIsNaN) {
+    result.value = NotANumber();
+  } else if (ABS().Compare(y.ABS()) == Relation::Less) {
+    return y.HYPOT(*this, rounding);
+  } else if (IsZero()) {
+    return result; // x==y==0
+  } else {
+    auto yOverX{y.Divide(*this, rounding)}; // y/x
+    bool inexact{yOverX.flags.test(RealFlag::Inexact)};
+    auto squared{yOverX.value.Multiply(yOverX.value, rounding)}; // (y/x)**2
+    inexact |= squared.flags.test(RealFlag::Inexact);
+    Real one;
+    one.Normalize(false, exponentBias, Fraction::MASKL(1)); // 1.0
+    auto sum{squared.value.Add(one, rounding)}; // 1.0 + (y/x)**2
+    inexact |= sum.flags.test(RealFlag::Inexact);
+    auto sqrt{sum.value.SQRT(rounding)};
+    inexact |= sqrt.flags.test(RealFlag::Inexact);
+    result = sqrt.value.Multiply(ABS(), rounding);
+    if (inexact) {
+      result.flags.set(RealFlag::Inexact);
+    }
+  }
+  return result;
+}
+
 template <typename W, int P>
 ValueWithRealFlags<Real<W, P>> Real<W, P>::ToWholeNumber(
     common::RoundingMode mode) const {

diff  --git a/flang/test/Evaluate/folding28.f90 b/flang/test/Evaluate/folding28.f90
new file mode 100644
index 0000000000000..406fc06afd381
--- /dev/null
+++ b/flang/test/Evaluate/folding28.f90
@@ -0,0 +1,45 @@
+! RUN: %S/test_folding.sh %s %t %flang_fc1
+! REQUIRES: shell
+! Tests folding of SQRT() and HYPOT()
+module m
+  implicit none
+  ! +Inf
+  real(8), parameter :: inf8 = z'7ff0000000000000'
+  logical, parameter :: test_inf8 = sqrt(inf8) == inf8
+  ! max finite
+  real(8), parameter :: h8 = huge(1.0_8), h8z = z'7fefffffffffffff'
+  logical, parameter :: test_h8 = h8 == h8z
+  real(8), parameter :: sqrt_h8 = sqrt(h8), sqrt_h8z = z'5fefffffffffffff'
+  logical, parameter :: test_sqrt_h8 = sqrt_h8 == sqrt_h8z
+  real(8), parameter :: sqr_sqrt_h8 = sqrt_h8 * sqrt_h8, sqr_sqrt_h8z = z'7feffffffffffffe'
+  logical, parameter :: test_sqr_sqrt_h8 = sqr_sqrt_h8 == sqr_sqrt_h8z
+  ! -0 (sqrt is -0)
+  real(8), parameter :: n08 = z'8000000000000000'
+  real(8), parameter :: sqrt_n08 = sqrt(n08)
+!WARN: division by zero
+  real(8), parameter :: inf_n08 = 1.0_8 / sqrt_n08, inf_n08z = z'fff0000000000000'
+  logical, parameter :: test_n08 = inf_n08 == inf_n08z
+  ! min normal
+  real(8), parameter :: t8 = tiny(1.0_8), t8z = z'0010000000000000'
+  logical, parameter :: test_t8 = t8 == t8z
+  real(8), parameter :: sqrt_t8 = sqrt(t8), sqrt_t8z = z'2000000000000000'
+  logical, parameter :: test_sqrt_t8 = sqrt_t8 == sqrt_t8z
+  real(8), parameter :: sqr_sqrt_t8 = sqrt_t8 * sqrt_t8
+  logical, parameter :: test_sqr_sqrt_t8 = sqr_sqrt_t8 == t8
+  ! max subnormal (its square root rounds down, just below 2**(-511))
+  real(8), parameter :: maxs8 = z'000fffffffffffff'
+  real(8), parameter :: sqrt_maxs8 = sqrt(maxs8), sqrt_maxs8z = z'1fffffffffffffff'
+  logical, parameter :: test_sqrt_maxs8 = sqrt_maxs8 == sqrt_maxs8z
+  ! min subnormal
+  real(8), parameter :: mins8 = z'1'
+  real(8), parameter :: sqrt_mins8 = sqrt(mins8), sqrt_mins8z = z'1e60000000000000'
+  logical, parameter :: test_sqrt_mins8 = sqrt_mins8 == sqrt_mins8z
+  real(8), parameter :: sqr_sqrt_mins8 = sqrt_mins8 * sqrt_mins8
+  logical, parameter :: test_sqr_sqrt_mins8 = sqr_sqrt_mins8 == mins8
+  ! negative odd exponent: SQRT(0.5625) == 0.75 exactly
+  real(8), parameter :: sqrt_o8 = sqrt(0.5625_8)
+  logical, parameter :: test_sqrt_o8 = sqrt_o8 == 0.75_8
+  ! HYPOT(+Inf, +Inf) == +Inf (not Inf/Inf == NaN)
+  logical, parameter :: test_hypot_inf8 = hypot(inf8, inf8) == inf8
+end module
+


        


More information about the flang-commits mailing list