[flang-commits] [flang] 1ef5e6d - [flang] Make SQRT folding exact

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Thu Jun 23 11:22:04 PDT 2022


Author: Peter Klausler
Date: 2022-06-23T11:16:39-07:00
New Revision: 1ef5e6de7605434cb6316847a46eb01c23c70715

URL: https://github.com/llvm/llvm-project/commit/1ef5e6de7605434cb6316847a46eb01c23c70715
DIFF: https://github.com/llvm/llvm-project/commit/1ef5e6de7605434cb6316847a46eb01c23c70715.diff

LOG: [flang] Make SQRT folding exact

Replace the latter half of the SQRT() folding algorithm with code that
calculates an exact root with extra rounding bits, and then lets the
usual normalization and rounding code do the right thing.  Extend
tests to catch regressions.

Differential Revision: https://reviews.llvm.org/D128395

Added: 
    

Modified: 
    flang/lib/Evaluate/real.cpp
    flang/test/Evaluate/folding28.f90
    flang/unittests/Evaluate/real.cpp

Removed: 
    


################################################################################
diff --git a/flang/lib/Evaluate/real.cpp b/flang/lib/Evaluate/real.cpp
index 3b81d23afe41c..b7230f891fa80 100644
--- a/flang/lib/Evaluate/real.cpp
+++ b/flang/lib/Evaluate/real.cpp
@@ -274,6 +274,7 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
       // SQRT(-0) == -0 in IEEE-754.
       result.value = NegativeZero();
     } else {
+      result.flags.set(RealFlag::InvalidArgument);
       result.value = NotANumber();
     }
   } else if (IsInfinite()) {
@@ -297,53 +298,31 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
           result.value.GetFraction());
       return result;
     }
-    // Compute the square root of the reduced value with the slow but
-    // reliable bit-at-a-time method.  Start with a clear significand and
-    // half of the unbiased exponent, and then try to set significand bits
-    // in descending order of magnitude without exceeding the exact result.
-    expo = expo / 2 + exponentBias;
-    result.value.Normalize(false, expo, Fraction::MASKL(1));
-    Real initialSq{result.value.Multiply(result.value).value};
-    if (Compare(initialSq) == Relation::Less) {
-      // Initial estimate is too large; this can happen for values just
-      // under 1.0.
-      --expo;
-      result.value.Normalize(false, expo, Fraction::MASKL(1));
-    }
-    for (int bit{significandBits - 1}; bit >= 0; --bit) {
-      Word word{result.value.word_};
-      result.value.word_ = word.IBSET(bit);
-      auto squared{result.value.Multiply(result.value, rounding)};
-      if (squared.flags.test(RealFlag::Overflow) ||
-          squared.flags.test(RealFlag::Underflow) ||
-          Compare(squared.value) == Relation::Less) {
-        result.value.word_ = word;
-      }
-    }
-    // The computed square root has a square that's not greater than the
-    // original argument.  Check this square against the square of the next
-    // larger Real and return that one if its square is closer in magnitude to
-    // the original argument.
-    Real resultSq{result.value.Multiply(result.value).value};
-    Real diff{Subtract(resultSq).value.ABS()};
-    if (diff.IsZero()) {
-      return result; // exact
-    }
-    Real ulp;
-    ulp.Normalize(false, expo, Fraction::MASKR(1));
-    Real nextAfter{result.value.Add(ulp).value};
-    auto nextAfterSq{nextAfter.Multiply(nextAfter)};
-    if (!nextAfterSq.flags.test(RealFlag::Overflow) &&
-        !nextAfterSq.flags.test(RealFlag::Underflow)) {
-      Real nextAfterDiff{Subtract(nextAfterSq.value).value.ABS()};
-      if (nextAfterDiff.Compare(diff) == Relation::Less) {
-        result.value = nextAfter;
-        if (nextAfterDiff.IsZero()) {
-          return result; // exact
-        }
+    // (-1) <= expo <= 1; use it as a shift to set the desired square.
+    using Extended = typename value::Integer<(binaryPrecision + 2)>;
+    Extended goal{
+        Extended::ConvertUnsigned(GetFraction()).value.SHIFTL(expo + 1)};
+    // Calculate the exact square root by maximizing a value whose square
+    // does not exceed the goal.  Use two extra bits of precision for
+    // rounding.
+    bool sticky{true};
+    Extended extFrac{};
+    for (int bit{Extended::bits - 1}; bit >= 0; --bit) {
+      Extended next{extFrac.IBSET(bit)};
+      auto squared{next.MultiplyUnsigned(next)};
+      auto cmp{squared.upper.CompareUnsigned(goal)};
+      if (cmp == Ordering::Less) {
+        extFrac = next;
+      } else if (cmp == Ordering::Equal && squared.lower.IsZero()) {
+        extFrac = next;
+        sticky = false;
+        break; // exact result
       }
     }
-    result.flags.set(RealFlag::Inexact);
+    RoundingBits roundingBits{extFrac.BTEST(1), extFrac.BTEST(0), sticky};
+    NormalizeAndRound(result, false, exponentBias,
+        Fraction::ConvertUnsigned(extFrac.SHIFTR(2)).value, rounding,
+        roundingBits);
   }
   return result;
 }

diff --git a/flang/test/Evaluate/folding28.f90 b/flang/test/Evaluate/folding28.f90
index 004c661692fee..642919de7414a 100644
--- a/flang/test/Evaluate/folding28.f90
+++ b/flang/test/Evaluate/folding28.f90
@@ -49,4 +49,25 @@ module m
   logical, parameter :: test_sqrt_zero_4 = sqrt_zero_4 == 0.0
   real(8), parameter :: sqrt_zero_8 = sqrt(0.0)
   logical, parameter :: test_sqrt_zero_8 = sqrt_zero_8 == 0.0
+  ! Some common values to get right
+  real(8), parameter :: sqrt_1_8 = sqrt(1.d0)
+  logical, parameter :: test_sqrt_1_8 = sqrt_1_8 == 1.d0
+  real(8), parameter :: sqrt_2_8 = sqrt(2.d0)
+  logical, parameter :: test_sqrt_2_8 = sqrt_2_8 == 1.4142135623730951454746218587388284504413604736328125d0
+  real(8), parameter :: sqrt_3_8 = sqrt(3.d0)
+  logical, parameter :: test_sqrt_3_8 = sqrt_3_8 == 1.732050807568877193176604123436845839023590087890625d0
+  real(8), parameter :: sqrt_4_8 = sqrt(4.d0)
+  logical, parameter :: test_sqrt_4_8 = sqrt_4_8 == 2.d0
+  real(8), parameter :: sqrt_5_8 = sqrt(5.d0)
+  logical, parameter :: test_sqrt_5_8 = sqrt_5_8 == 2.236067977499789805051477742381393909454345703125d0
+  real(8), parameter :: sqrt_6_8 = sqrt(6.d0)
+  logical, parameter :: test_sqrt_6_8 = sqrt_6_8 == 2.44948974278317788133563226438127458095550537109375d0
+  real(8), parameter :: sqrt_7_8 = sqrt(7.d0)
+  logical, parameter :: test_sqrt_7_8 = sqrt_7_8 == 2.64575131106459071617109657381661236286163330078125d0
+  real(8), parameter :: sqrt_8_8 = sqrt(8.d0)
+  logical, parameter :: test_sqrt_8_8 = sqrt_8_8 == 2.828427124746190290949243717477656900882720947265625d0
+  real(8), parameter :: sqrt_9_8 = sqrt(9.d0)
+  logical, parameter :: test_sqrt_9_8 = sqrt_9_8 == 3.d0
+  real(8), parameter :: sqrt_10_8 = sqrt(10.d0)
+  logical, parameter :: test_sqrt_10_8 = sqrt_10_8 == 3.162277660168379522787063251598738133907318115234375d0
 end module

diff --git a/flang/unittests/Evaluate/real.cpp b/flang/unittests/Evaluate/real.cpp
index 1974f42624415..60e5710b52a43 100644
--- a/flang/unittests/Evaluate/real.cpp
+++ b/flang/unittests/Evaluate/real.cpp
@@ -392,6 +392,22 @@ void subsetTests(int pass, Rounding rounding, std::uint32_t opds) {
       ("%d AINT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
     }
 
+    {
+      ValueWithRealFlags<REAL> root{x.SQRT(rounding)};
+#ifndef __clang__ // broken and also slow
+      fpenv.ClearFlags();
+#endif
+      FLT fcheck{std::sqrt(fj)};
+      auto actualFlags{FlagsToBits(fpenv.CurrentFlags())};
+      u.f = fcheck;
+      UINT rcheck{NormalizeNaN(u.ui)};
+      UINT check = root.value.RawBits().ToUInt64();
+      MATCH(rcheck, check)
+      ("%d SQRT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
+      MATCH(actualFlags, FlagsToBits(root.flags))
+      ("%d SQRT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
+    }
+
     {
       MATCH(IsNaN(rj), x.IsNotANumber())
       ("%d IsNaN(0x%jx)", pass, static_cast<std::intmax_t>(rj));


        


More information about the flang-commits mailing list