[libc-commits] [libc] [llvm] [libc][math] Fix signaling NaN handling of hypot(f) and improve hypotf performance. (PR #99432)

via libc-commits libc-commits at lists.llvm.org
Thu Jul 18 16:32:31 PDT 2024


https://github.com/lntue updated https://github.com/llvm/llvm-project/pull/99432

>From 4842832c0303cd4ef222503849f38c6ea7711b83 Mon Sep 17 00:00:00 2001
From: Tue Ly <lntue.h at gmail.com>
Date: Thu, 18 Jul 2024 05:03:38 +0000
Subject: [PATCH 1/2] [libc][math] Fix signaling NaN handling of hypot(f) and
 improve hypotf performance.

---
 libc/src/__support/FPUtil/Hypot.h             | 63 ++++++------
 libc/src/math/generic/CMakeLists.txt          |  5 +-
 libc/src/math/generic/hypotf.cpp              | 98 ++++++++++++-------
 libc/test/src/math/smoke/HypotTest.h          | 27 ++---
 .../llvm-project-overlay/libc/BUILD.bazel     |  1 +
 5 files changed, 104 insertions(+), 90 deletions(-)

diff --git a/libc/src/__support/FPUtil/Hypot.h b/libc/src/__support/FPUtil/Hypot.h
index a5a991e75ed01..6aa808446d6d9 100644
--- a/libc/src/__support/FPUtil/Hypot.h
+++ b/libc/src/__support/FPUtil/Hypot.h
@@ -109,45 +109,39 @@ LIBC_INLINE T hypot(T x, T y) {
   using StorageType = typename FPBits<T>::StorageType;
   using DStorageType = typename DoubleLength<StorageType>::Type;
 
-  FPBits_t x_bits(x), y_bits(y);
+  FPBits_t x_abs = FPBits_t(x).abs();
+  FPBits_t y_abs = FPBits_t(y).abs();
 
-  if (x_bits.is_inf() || y_bits.is_inf()) {
-    return FPBits_t::inf().get_val();
-  }
-  if (x_bits.is_nan()) {
-    return x;
-  }
-  if (y_bits.is_nan()) {
+  bool x_abs_larger = x_abs.uintval() >= y_abs.uintval();
+
+  FPBits_t a_bits = x_abs_larger ? x_abs : y_abs;
+  FPBits_t b_bits = x_abs_larger ? y_abs : x_abs;
+
+  if (LIBC_UNLIKELY(a_bits.is_inf_or_nan())) {
+    if (x_abs.is_signaling_nan() || y_abs.is_signaling_nan()) {
+      fputil::raise_except_if_required(FE_INVALID);
+      return FPBits_t::quiet_nan().get_val();
+    }
+    if (x_abs.is_inf() || y_abs.is_inf())
+      return FPBits_t::inf().get_val();
+    if (x_abs.is_nan())
+      return x;
+    // y is nan
     return y;
   }
 
-  uint16_t x_exp = x_bits.get_biased_exponent();
-  uint16_t y_exp = y_bits.get_biased_exponent();
-  uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp);
+  uint16_t a_exp = a_bits.get_biased_exponent();
+  uint16_t b_exp = b_bits.get_biased_exponent();
 
-  if ((exp_diff >= FPBits_t::FRACTION_LEN + 2) || (x == 0) || (y == 0)) {
-    return abs(x) + abs(y);
-  }
+  if ((a_exp - b_exp >= FPBits_t::FRACTION_LEN + 2) || (x == 0) || (y == 0))
+    return x_abs.get_val() + y_abs.get_val();
 
-  uint16_t a_exp, b_exp, out_exp;
-  StorageType a_mant, b_mant;
+  uint64_t out_exp = a_exp;
+  StorageType a_mant = a_bits.get_mantissa();
+  StorageType b_mant = b_bits.get_mantissa();
   DStorageType a_mant_sq, b_mant_sq;
   bool sticky_bits;
 
-  if (abs(x) >= abs(y)) {
-    a_exp = x_exp;
-    a_mant = x_bits.get_mantissa();
-    b_exp = y_exp;
-    b_mant = y_bits.get_mantissa();
-  } else {
-    a_exp = y_exp;
-    a_mant = y_bits.get_mantissa();
-    b_exp = x_exp;
-    b_mant = x_bits.get_mantissa();
-  }
-
-  out_exp = a_exp;
-
   // Add an extra bit to simplify the final rounding bit computation.
   constexpr StorageType ONE = StorageType(1) << (FPBits_t::FRACTION_LEN + 1);
 
@@ -165,11 +159,10 @@ LIBC_INLINE T hypot(T x, T y) {
     a_exp = 1;
   }
 
-  if (b_exp != 0) {
+  if (b_exp != 0)
     b_mant |= ONE;
-  } else {
+  else
     b_exp = 1;
-  }
 
   a_mant_sq = static_cast<DStorageType>(a_mant) * a_mant;
   b_mant_sq = static_cast<DStorageType>(b_mant) * b_mant;
@@ -260,6 +253,10 @@ LIBC_INLINE T hypot(T x, T y) {
   }
 
   y_new |= static_cast<StorageType>(out_exp) << FPBits_t::FRACTION_LEN;
+
+  if (!(round_bit || sticky_bits || (r != 0)))
+    fputil::clear_except_if_required(FE_INEXACT);
+
   return cpp::bit_cast<T>(y_new);
 }
 
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index 415ca3fbce796..413ece04e9421 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -2809,9 +2809,12 @@ add_entrypoint_object(
   HDRS
     ../hypotf.h
   DEPENDS
-    libc.src.__support.FPUtil.basic_operations
+    libc.src.__support.FPUtil.double_double
+    libc.src.__support.FPUtil.fenv_impl
     libc.src.__support.FPUtil.fp_bits
+    libc.src.__support.FPUtil.multiply_add
     libc.src.__support.FPUtil.sqrt
+    libc.src.__support.macros.optimization
   COMPILE_OPTIONS
     -O3
 )
diff --git a/libc/src/math/generic/hypotf.cpp b/libc/src/math/generic/hypotf.cpp
index 75c55ed846726..478209a14f454 100644
--- a/libc/src/math/generic/hypotf.cpp
+++ b/libc/src/math/generic/hypotf.cpp
@@ -6,11 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 #include "src/math/hypotf.h"
-#include "src/__support/FPUtil/BasicOperations.h"
+#include "src/__support/FPUtil/FEnvImpl.h"
 #include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/FPUtil/double_double.h"
+#include "src/__support/FPUtil/multiply_add.h"
 #include "src/__support/FPUtil/sqrt.h"
 #include "src/__support/common.h"
 #include "src/__support/macros/config.h"
+#include "src/__support/macros/optimization.h"
 
 namespace LIBC_NAMESPACE_DECL {
 
@@ -18,54 +21,73 @@ LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
   using DoubleBits = fputil::FPBits<double>;
   using FPBits = fputil::FPBits<float>;
 
-  FPBits x_bits(x), y_bits(y);
+  FPBits x_abs = FPBits(x).abs();
+  FPBits y_abs = FPBits(y).abs();
 
-  uint16_t x_exp = x_bits.get_biased_exponent();
-  uint16_t y_exp = y_bits.get_biased_exponent();
-  uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp);
+  bool x_abs_larger = x_abs.uintval() >= y_abs.uintval();
 
-  if (exp_diff >= FPBits::FRACTION_LEN + 2) {
-    return fputil::abs(x) + fputil::abs(y);
-  }
+  FPBits a_bits = x_abs_larger ? x_abs : y_abs;
+  FPBits b_bits = x_abs_larger ? y_abs : x_abs;
 
-  double xd = static_cast<double>(x);
-  double yd = static_cast<double>(y);
+  uint32_t a_u = a_bits.uintval();
+  uint32_t b_u = b_bits.uintval();
 
-  // These squares are exact.
-  double x_sq = xd * xd;
-  double y_sq = yd * yd;
+  if (LIBC_UNLIKELY(a_u >= FPBits::EXP_MASK)) {
+    // x or y is inf or nan
+    if (a_bits.is_signaling_nan() || b_bits.is_signaling_nan()) {
+      fputil::raise_except_if_required(FE_INVALID);
+      return FPBits::quiet_nan().get_val();
+    }
+    if (a_bits.is_inf() || b_bits.is_inf())
+      return FPBits::inf().get_val();
+    return a_bits.get_val();
+  }
 
-  // Compute the sum of squares.
-  double sum_sq = x_sq + y_sq;
+  if (LIBC_UNLIKELY(a_u - b_u >=
+                    static_cast<uint32_t>((FPBits::FRACTION_LEN + 2)
+                                          << FPBits::FRACTION_LEN)))
+    return x_abs.get_val() + y_abs.get_val();
 
-  // Compute the rounding error with Fast2Sum algorithm:
-  // x_sq + y_sq = sum_sq - err
-  double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq;
+  double ad = static_cast<double>(a_bits.get_val());
+  double bd = static_cast<double>(b_bits.get_val());
+
+  // These squares are exact.
+  double a_sq = ad * ad;
+#ifdef LIBC_TARGET_CPU_HAS_FMA
+  double sum_sq = fputil::multiply_add(bd, bd, a_sq);
+#else
+  double b_sq = bd * bd;
+  double sum_sq = a_sq + b_sq;
+#endif
 
   // Take sqrt in double precision.
   DoubleBits result(fputil::sqrt<double>(sum_sq));
+  uint64_t r_u = result.uintval();
 
-  if (!DoubleBits(sum_sq).is_inf_or_nan()) {
-    // Correct rounding.
-    double r_sq = result.get_val() * result.get_val();
-    double diff = sum_sq - r_sq;
-    constexpr uint64_t MASK = 0x0000'0000'3FFF'FFFFULL;
-    uint64_t lrs = result.uintval() & MASK;
-
-    if (lrs == 0x0000'0000'1000'0000ULL && err < diff) {
-      result.set_uintval(result.uintval() | 1ULL);
-    } else if (lrs == 0x0000'0000'3000'0000ULL && err > diff) {
-      result.set_uintval(result.uintval() - 1ULL);
-    }
-  } else {
-    FPBits bits_x(x), bits_y(y);
-    if (bits_x.is_inf_or_nan() || bits_y.is_inf_or_nan()) {
-      if (bits_x.is_inf() || bits_y.is_inf())
-        return FPBits::inf().get_val();
-      if (bits_x.is_nan())
-        return x;
-      return y;
+  // If any of the sticky bits of the result are non-zero, except the LSB, then
+  // the rounded result is correct.
+  if (LIBC_UNLIKELY(((r_u + 1) & 0x0000'0000'0FFF'FFFE) == 0)) {
+    double r_d = result.get_val();
+
+    // Perform rounding correction.
+#ifdef LIBC_TARGET_CPU_HAS_FMA
+    double sum_sq_lo = fputil::multiply_add(bd, bd, a_sq - sum_sq);
+    double err = sum_sq_lo - fputil::multiply_add(r_d, r_d, -sum_sq);
+#else
+    fputil::DoubleDouble r_sq = fputil::exact_mult(r_d, r_d);
+    double sum_sq_lo = b_sq - (sum_sq - a_sq);
+    double err = (sum_sq - r_sq.hi) + (sum_sq_lo - r_sq.lo);
+#endif
+
+    if (err > 0)
+      r_u |= 1;
+    else if ((err < 0) && (r_u & 1) == 0)
+      r_u -= 1;
+    else if ((r_u & 0x0000'0000'1FFF'FFFF) == 0) {
+      // The rounded result is exact.
+      fputil::clear_except_if_required(FE_INEXACT);
     }
+    return static_cast<float>(DoubleBits(r_u).get_val());
   }
 
   return static_cast<float>(result.get_val());
diff --git a/libc/test/src/math/smoke/HypotTest.h b/libc/test/src/math/smoke/HypotTest.h
index 80e9bb7366dfe..6efc3c8509c7f 100644
--- a/libc/test/src/math/smoke/HypotTest.h
+++ b/libc/test/src/math/smoke/HypotTest.h
@@ -17,22 +17,11 @@
 #include "hdr/math_macros.h"
 
 template <typename T>
-class HypotTestTemplate : public LIBC_NAMESPACE::testing::FEnvSafeTest {
+class HypotTestTemplate : public LIBC_NAMESPACE::testing::Test {
 private:
   using Func = T (*)(T, T);
-  using FPBits = LIBC_NAMESPACE::fputil::FPBits<T>;
-  using StorageType = typename FPBits::StorageType;
 
-  const T nan = FPBits::quiet_nan().get_val();
-  const T inf = FPBits::inf(Sign::POS).get_val();
-  const T neg_inf = FPBits::inf(Sign::NEG).get_val();
-  const T zero = FPBits::zero(Sign::POS).get_val();
-  const T neg_zero = FPBits::zero(Sign::NEG).get_val();
-
-  const T max_normal = FPBits::max_normal().get_val();
-  const T min_normal = FPBits::min_normal().get_val();
-  const T max_subnormal = FPBits::max_subnormal().get_val();
-  const T min_subnormal = FPBits::min_subnormal().get_val();
+  DECLARE_SPECIAL_CONSTANTS(T)
 
 public:
   void test_special_numbers(Func func) {
@@ -40,11 +29,13 @@ class HypotTestTemplate : public LIBC_NAMESPACE::testing::FEnvSafeTest {
     // Pythagorean triples.
     constexpr T PYT[N][3] = {{3, 4, 5}, {5, 12, 13}, {8, 15, 17}, {7, 24, 25}};
 
-    EXPECT_FP_EQ(func(inf, nan), inf);
-    EXPECT_FP_EQ(func(nan, neg_inf), inf);
-    EXPECT_FP_EQ(func(nan, nan), nan);
-    EXPECT_FP_EQ(func(nan, zero), nan);
-    EXPECT_FP_EQ(func(neg_zero, nan), nan);
+    EXPECT_FP_EQ(func(inf, sNaN), aNaN);
+    EXPECT_FP_EQ(func(sNaN, neg_inf), aNaN);
+    EXPECT_FP_EQ(func(inf, aNaN), inf);
+    EXPECT_FP_EQ(func(aNaN, neg_inf), inf);
+    EXPECT_FP_EQ(func(aNaN, aNaN), aNaN);
+    EXPECT_FP_EQ(func(aNaN, zero), aNaN);
+    EXPECT_FP_EQ(func(neg_zero, aNaN), aNaN);
 
     for (int i = 0; i < N; ++i) {
       EXPECT_FP_EQ_ALL_ROUNDING(PYT[i][2], func(PYT[i][0], PYT[i][1]));
diff --git a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
index f0c25e658fd18..63fd3faf38aef 100644
--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
@@ -2020,6 +2020,7 @@ libc_math_function(name = "hypot")
 libc_math_function(
     name = "hypotf",
     additional_deps = [
+        ":__support_fputil_double_double",
         ":__support_fputil_sqrt",
     ],
 )

>From ee3bae932b77c993d00c1842cb61720de0e7b847 Mon Sep 17 00:00:00 2001
From: Tue Ly <lntue.h at gmail.com>
Date: Thu, 18 Jul 2024 23:32:03 +0000
Subject: [PATCH 2/2] Address review comments: add braces to the if/else-if chain in hypotf.

---
 libc/src/math/generic/hypotf.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/libc/src/math/generic/hypotf.cpp b/libc/src/math/generic/hypotf.cpp
index 478209a14f454..4ccf4119800d1 100644
--- a/libc/src/math/generic/hypotf.cpp
+++ b/libc/src/math/generic/hypotf.cpp
@@ -79,11 +79,11 @@ LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
     double err = (sum_sq - r_sq.hi) + (sum_sq_lo - r_sq.lo);
 #endif
 
-    if (err > 0)
+    if (err > 0) {
       r_u |= 1;
-    else if ((err < 0) && (r_u & 1) == 0)
+    } else if ((err < 0) && (r_u & 1) == 0) {
       r_u -= 1;
-    else if ((r_u & 0x0000'0000'1FFF'FFFF) == 0) {
+    } else if ((r_u & 0x0000'0000'1FFF'FFFF) == 0) {
       // The rounded result is exact.
       fputil::clear_except_if_required(FE_INEXACT);
     }



More information about the libc-commits mailing list