[libc-commits] [libc] [libc][math] Fix overflow shifts for dyadic floats and add skip accuracy option for expm1. (PR #98048)

via libc-commits libc-commits at lists.llvm.org
Tue Jul 9 12:20:54 PDT 2024


https://github.com/lntue updated https://github.com/llvm/llvm-project/pull/98048

>From 0b8f7d4ba4db21a7676ea4c02c80875dc5f53ff8 Mon Sep 17 00:00:00 2001
From: Tue Ly <lntue.h at gmail.com>
Date: Mon, 8 Jul 2024 16:49:33 +0000
Subject: [PATCH 1/3] [libc][math] Fix overflow shifts for dyadic floats and
 add skip accuracy option for expm1.

---
 libc/src/__support/FPUtil/dyadic_float.h | 17 ++++++---
 libc/src/math/generic/expm1.cpp          | 31 ++++++++++------
 libc/test/src/math/expm1_test.cpp        | 45 +++++++++---------------
 libc/test/src/math/tan_test.cpp          | 18 +++++-----
 4 files changed, 61 insertions(+), 50 deletions(-)

diff --git a/libc/src/__support/FPUtil/dyadic_float.h b/libc/src/__support/FPUtil/dyadic_float.h
index 8d44a98a693f8..79fb9c362ed69 100644
--- a/libc/src/__support/FPUtil/dyadic_float.h
+++ b/libc/src/__support/FPUtil/dyadic_float.h
@@ -260,10 +260,19 @@ LIBC_INLINE constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
     return a;
 
   // Align exponents
-  if (a.exponent > b.exponent)
-    b.shift_right(a.exponent - b.exponent);
-  else if (b.exponent > a.exponent)
-    a.shift_right(b.exponent - a.exponent);
+  if (a.exponent > b.exponent) {
+    size_t shift = static_cast<size_t>(a.exponent - b.exponent);
+    if (shift < Bits)
+      b.shift_right(static_cast<int>(shift));
+    else
+      b = DyadicFloat<Bits>();
+  } else if (b.exponent > a.exponent) {
+    size_t shift = static_cast<size_t>(b.exponent - a.exponent);
+    if (shift < Bits)
+      a.shift_right(static_cast<int>(shift));
+    else
+      a = DyadicFloat<Bits>();
+  }
 
   DyadicFloat<Bits> result;
 
diff --git a/libc/src/math/generic/expm1.cpp b/libc/src/math/generic/expm1.cpp
index 574c4b9aaf39f..150c0bbcf60da 100644
--- a/libc/src/math/generic/expm1.cpp
+++ b/libc/src/math/generic/expm1.cpp
@@ -25,7 +25,9 @@
 #include "src/__support/integer_literals.h"
 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
 
-#include <errno.h>
+#if ((LIBC_MATH & LIBC_MATH_SKIP_ACCURATE_PASS) != 0)
+#define LIBC_MATH_EXPM1_SKIP_ACCURATE_PASS
+#endif
 
 // #define DEBUGDEBUG
 
@@ -51,7 +53,7 @@ constexpr double LOG2_E = 0x1.71547652b82fep+0;
 constexpr uint64_t ERR_D = 0x3c08000000000000;
 // Errors when using double-double precision.
 // 0x1.0p-99
-constexpr uint64_t ERR_DD = 0x39c0000000000000;
+[[maybe_unused]] constexpr uint64_t ERR_DD = 0x39c0000000000000;
 
 // -2^-12 * log(2)
 // > a = -2^-12 * log(2);
@@ -108,7 +110,7 @@ DoubleDouble poly_approx_dd(const DoubleDouble &dx) {
 // Return (exp(dx) - 1)/dx ~ 1 + dx / 2 + dx^2 / 6 + ... + dx^6 / 5040
 // For |dx| < 2^-13 + 2^-30:
 //   | output - exp(dx) | < 2^-126.
-Float128 poly_approx_f128(const Float128 &dx) {
+[[maybe_unused]] Float128 poly_approx_f128(const Float128 &dx) {
   constexpr Float128 COEFFS_128[]{
       {Sign::POS, -127, 0x80000000'00000000'00000000'00000000_u128}, // 1.0
       {Sign::POS, -128, 0x80000000'00000000'00000000'00000000_u128}, // 0.5
@@ -127,13 +129,14 @@ Float128 poly_approx_f128(const Float128 &dx) {
 
 #ifdef DEBUGDEBUG
 std::ostream &operator<<(std::ostream &OS, const Float128 &r) {
-  OS << (r.sign ? "-(" : "(") << r.mantissa.val[0] << " + " << r.mantissa.val[1]
-     << " * 2^64) * 2^" << r.exponent << "\n";
+  OS << (r.sign == Sign::NEG ? "-(" : "(") << r.mantissa.val[0] << " + "
+     << r.mantissa.val[1] << " * 2^64) * 2^" << r.exponent << "\n";
   return OS;
 }
 
 std::ostream &operator<<(std::ostream &OS, const DoubleDouble &r) {
-  OS << std::hexfloat << r.hi << " + " << r.lo << std::defaultfloat << "\n";
+  OS << std::hexfloat << "(" << r.hi << " + " << r.lo << ")"
+     << std::defaultfloat << "\n";
   return OS;
 }
 #endif
@@ -141,7 +144,7 @@ std::ostream &operator<<(std::ostream &OS, const DoubleDouble &r) {
 // Compute exp(x) - 1 using 128-bit precision.
 // TODO(lntue): investigate triple-double precision implementation for this
 // step.
-Float128 expm1_f128(double x, double kd, int idx1, int idx2) {
+[[maybe_unused]] Float128 expm1_f128(double x, double kd, int idx1, int idx2) {
   // Recalculate dx:
 
   double t1 = fputil::multiply_add(kd, MLOG_2_EXP2_M12_HI, x); // exact
@@ -182,9 +185,10 @@ Float128 expm1_f128(double x, double kd, int idx1, int idx2) {
 #ifdef DEBUGDEBUG
   std::cout << "=== VERY SLOW PASS ===\n"
             << "        kd: " << kd << "\n"
-            << "        dx: " << dx << "exp_mid_m1: " << exp_mid_m1
-            << "   exp_mid: " << exp_mid << "         p: " << p
-            << "         r: " << r << std::endl;
+            << "        hi: " << hi << "\n"
+            << " minus_one: " << minus_one << "        dx: " << dx
+            << "exp_mid_m1: " << exp_mid_m1 << "   exp_mid: " << exp_mid
+            << "         p: " << p << "         r: " << r << std::endl;
 #endif
 
   return r;
@@ -479,6 +483,12 @@ LLVM_LIBC_FUNCTION(double, expm1, (double x)) {
   // Use double-double
   DoubleDouble r_dd = exp_double_double(x, kd, exp_mid, hi_part);
 
+#ifdef LIBC_MATH_EXPM1_SKIP_ACCURATE_PASS
+  int64_t exp_hi = static_cast<int64_t>(hi) << FPBits::FRACTION_LEN;
+  double r =
+      cpp::bit_cast<double>(exp_hi + cpp::bit_cast<int64_t>(r_dd.hi + r_dd.lo));
+  return r;
+#else
   double err_dd = cpp::bit_cast<double>(ERR_DD + err);
 
   double upper_dd = r_dd.hi + (r_dd.lo + err_dd);
@@ -494,6 +504,7 @@ LLVM_LIBC_FUNCTION(double, expm1, (double x)) {
   Float128 r_f128 = expm1_f128(x, kd, idx1, idx2);
 
   return static_cast<double>(r_f128);
+#endif // LIBC_MATH_EXPM1_SKIP_ACCURATE_PASS
 }
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/test/src/math/expm1_test.cpp b/libc/test/src/math/expm1_test.cpp
index 1bf07f19f3a7c..df5c08864bb8a 100644
--- a/libc/test/src/math/expm1_test.cpp
+++ b/libc/test/src/math/expm1_test.cpp
@@ -14,7 +14,6 @@
 #include "test/UnitTest/Test.h"
 #include "utils/MPFRWrapper/MPFRUtils.h"
 
-#include <errno.h>
 #include <stdint.h>
 
 using LlvmLibcExpm1Test = LIBC_NAMESPACE::testing::FPTest<double>;
@@ -23,34 +22,24 @@ namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
 using LIBC_NAMESPACE::testing::tlog;
 
 TEST_F(LlvmLibcExpm1Test, TrickyInputs) {
-  constexpr int N = 21;
-  constexpr uint64_t INPUTS[N] = {
-      0x3FD79289C6E6A5C0, // x=0x1.79289c6e6a5cp-2
-      0x3FD05DE80A173EA0, // x=0x1.05de80a173eap-2
-      0xbf1eb7a4cb841fcc, // x=-0x1.eb7a4cb841fccp-14
-      0xbf19a61fb925970d, // x=-0x1.9a61fb925970dp-14
-      0x3fda7b764e2cf47a, // x=0x1.a7b764e2cf47ap-2
-      0xc04757852a4b93aa, // x=-0x1.757852a4b93aap+5
-      0x4044c19e5712e377, // x=0x1.4c19e5712e377p+5
-      0xbf19a61fb925970d, // x=-0x1.9a61fb925970dp-14
-      0xc039a74cdab36c28, // x=-0x1.9a74cdab36c28p+4
-      0xc085b3e4e2e3bba9, // x=-0x1.5b3e4e2e3bba9p+9
-      0xc086960d591aec34, // x=-0x1.6960d591aec34p+9
-      0xc086232c09d58d91, // x=-0x1.6232c09d58d91p+9
-      0xc0874910d52d3051, // x=-0x1.74910d52d3051p9
-      0xc0867a172ceb0990, // x=-0x1.67a172ceb099p+9
-      0xc08ff80000000000, // x=-0x1.ff8p+9
-      0xbc971547652b82fe, // x=-0x1.71547652b82fep-54
-      0xbce465655f122ff6, // x=-0x1.465655f122ff6p-49
-      0x3d1bc8ee6b28659a, // x=0x1.bc8ee6b28659ap-46
-      0x3f18442b169f672d, // x=0x1.8442b169f672dp-14
-      0xc02b4f0cfb15ca0f, // x=-0x1.b4f0cfb15ca0fp+3
-      0xc042b708872320dd, // x=-0x1.2b708872320ddp+5
+  constexpr double INPUTS[] = {
+      0x1.71547652b82fep-54, 0x1.465655f122ff6p-49, 0x1.bc8ee6b28659ap-46,
+      0x1.8442b169f672dp-14, 0x1.9a61fb925970dp-14, 0x1.eb7a4cb841fccp-14,
+      0x1.05de80a173eap-2,   0x1.79289c6e6a5cp-2,   0x1.a7b764e2cf47ap-2,
+      0x1.b4f0cfb15ca0fp+3,  0x1.9a74cdab36c28p+4,  0x1.2b708872320ddp+5,
+      0x1.4c19e5712e377p+5,  0x1.757852a4b93aap+5,  0x1.77f74111e0894p+6,
+      0x1.a6c3780bbf824p+6,  0x1.e3d57e4c557f6p+6,  0x1.f07560077985ap+6,
+      0x1.1f0da93354198p+7,  0x1.71018579c0758p+7,  0x1.204684c1167e9p+8,
+      0x1.5b3e4e2e3bba9p+9,  0x1.6232c09d58d91p+9,  0x1.67a172ceb099p+9,
+      0x1.6960d591aec34p+9,  0x1.74910d52d3051p+9,  0x1.ff8p+9,
   };
+  constexpr int N = sizeof(INPUTS) / sizeof(INPUTS[0]);
   for (int i = 0; i < N; ++i) {
-    double x = FPBits(INPUTS[i]).get_val();
+    double x = INPUTS[i];
     EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Expm1, x,
                                    LIBC_NAMESPACE::expm1(x), 0.5);
+    EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Expm1, -x,
+                                   LIBC_NAMESPACE::expm1(-x), 0.5);
   }
 }
 
@@ -98,10 +87,10 @@ TEST_F(LlvmLibcExpm1Test, InDoubleRange) {
         }
       }
     }
-    tlog << " Expm1 failed: " << fails << "/" << count << "/" << cc
-         << " tests.\n";
-    tlog << "   Max ULPs is at most: " << static_cast<uint64_t>(tol) << ".\n";
     if (fails) {
+      tlog << " Expm1 failed: " << fails << "/" << count << "/" << cc
+           << " tests.\n";
+      tlog << "   Max ULPs is at most: " << static_cast<uint64_t>(tol) << ".\n";
       EXPECT_MPFR_MATCH(mpfr::Operation::Expm1, mx, mr, 0.5, rounding_mode);
     }
   };
diff --git a/libc/test/src/math/tan_test.cpp b/libc/test/src/math/tan_test.cpp
index e9e3e59f4d12d..80d57939a4f61 100644
--- a/libc/test/src/math/tan_test.cpp
+++ b/libc/test/src/math/tan_test.cpp
@@ -22,14 +22,15 @@ TEST_F(LlvmLibcTanTest, TrickyInputs) {
   constexpr double INPUTS[] = {
       0x1.d130383d17321p-27,   0x1.8000000000009p-23,  0x1.8000000000024p-22,
       0x1.800000000009p-21,    0x1.20000000000f3p-20,  0x1.800000000024p-20,
-      0x1.e0000000001c2p-20,   0x1.0da8cc189b47dp-10,  0x1.00a33764a0a83p-7,
-      0x1.911a18779813fp-7,    0x1.940c877fb7dacp-7,   0x1.f42fb19b5b9b2p-6,
-      0x1.0285070f9f1bcp-5,    0x1.6ca9ef729af76p-1,   0x1.23f40dccdef72p+0,
-      0x1.43cf16358c9d7p+0,    0x1.addf3b9722265p+0,   0x1.ae78d360afa15p+0,
-      0x1.fe81868fc47fep+1,    0x1.e31b55306f22cp+2,   0x1.e639103a05997p+2,
-      0x1.f7898d5a756ddp+2,    0x1.1685973506319p+3,   0x1.5f09cad750ab1p+3,
-      0x1.aaf85537ea4c7p+3,    0x1.4f2b874135d27p+4,   0x1.13114266f9764p+4,
-      0x1.a211877de55dbp+4,    0x1.a5eece87e8606p+4,   0x1.a65d441ea6dcep+4,
+      0x1.e0000000001c2p-20,   0x1.00452f0e0134dp-13,  0x1.0da8cc189b47dp-10,
+      0x1.00a33764a0a83p-7,    0x1.911a18779813fp-7,   0x1.940c877fb7dacp-7,
+      0x1.f42fb19b5b9b2p-6,    0x1.0285070f9f1bcp-5,   0x1.89f0f5241255bp-2,
+      0x1.6ca9ef729af76p-1,    0x1.23f40dccdef72p+0,   0x1.43cf16358c9d7p+0,
+      0x1.addf3b9722265p+0,    0x1.ae78d360afa15p+0,   0x1.fe81868fc47fep+1,
+      0x1.e31b55306f22cp+2,    0x1.e639103a05997p+2,   0x1.f7898d5a756ddp+2,
+      0x1.1685973506319p+3,    0x1.5f09cad750ab1p+3,   0x1.aaf85537ea4c7p+3,
+      0x1.4f2b874135d27p+4,    0x1.13114266f9764p+4,   0x1.a211877de55dbp+4,
+      0x1.a5eece87e8606p+4,    0x1.a65d441ea6dcep+4,   0x1.045457ae3994p+5,
       0x1.1ffb509f3db15p+5,    0x1.2345d1e090529p+5,   0x1.c96e28eb679f8p+5,
       0x1.da1838053b866p+5,    0x1.be886d9c2324dp+6,   0x1.ab514bfc61c76p+7,
       0x1.14823229799c2p+7,    0x1.48ff1782ca91dp+8,   0x1.dcbfda0c7559ep+8,
@@ -42,6 +43,7 @@ TEST_F(LlvmLibcTanTest, TrickyInputs) {
       0x1.6ac5b262ca1ffp+843,  0x1.8bb5847d49973p+845, 0x1.6ac5b262ca1ffp+849,
       0x1.f08b14e1c4d0fp+890,  0x1.2b5fe88a9d8d5p+903, 0x1.a880417b7b119p+1023,
       0x1.f6d7518808571p+1023,
+
   };
   constexpr int N = sizeof(INPUTS) / sizeof(INPUTS[0]);
 

>From d4ab2c1d9ec30a1f1a8e5b43095c0d7a43f1fd0f Mon Sep 17 00:00:00 2001
From: Tue Ly <lntue.h at gmail.com>
Date: Tue, 9 Jul 2024 15:53:54 +0000
Subject: [PATCH 2/3] Address review comments: move the shift overflow check
 into DyadicFloat::shift_left/shift_right.

---
 libc/src/__support/FPUtil/dyadic_float.h | 35 ++++++++++++------------
 1 file changed, 18 insertions(+), 17 deletions(-)

diff --git a/libc/src/__support/FPUtil/dyadic_float.h b/libc/src/__support/FPUtil/dyadic_float.h
index 79fb9c362ed69..3c80e25350ba8 100644
--- a/libc/src/__support/FPUtil/dyadic_float.h
+++ b/libc/src/__support/FPUtil/dyadic_float.h
@@ -68,15 +68,25 @@ template <size_t Bits> struct DyadicFloat {
 
   // Used for aligning exponents.  Output might not be normalized.
   LIBC_INLINE constexpr DyadicFloat &shift_left(int shift_length) {
-    exponent -= shift_length;
-    mantissa <<= static_cast<size_t>(shift_length);
+    if (shift_length < Bits) {
+      exponent -= shift_length;
+      mantissa <<= static_cast<size_t>(shift_length);
+    } else {
+      exponent = 0;
+      mantissa = MantissaType(0);
+    }
     return *this;
   }
 
   // Used for aligning exponents.  Output might not be normalized.
   LIBC_INLINE constexpr DyadicFloat &shift_right(int shift_length) {
-    exponent += shift_length;
-    mantissa >>= static_cast<size_t>(shift_length);
+    if (shift_length < Bits) {
+      exponent += shift_length;
+      mantissa >>= static_cast<size_t>(shift_length);
+    } else {
+      exponent = 0;
+      mantissa = MantissaType(0);
+    }
     return *this;
   }
 
@@ -260,19 +270,10 @@ LIBC_INLINE constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
     return a;
 
   // Align exponents
-  if (a.exponent > b.exponent) {
-    size_t shift = static_cast<size_t>(a.exponent - b.exponent);
-    if (shift < Bits)
-      b.shift_right(static_cast<int>(shift));
-    else
-      b = DyadicFloat<Bits>();
-  } else if (b.exponent > a.exponent) {
-    size_t shift = static_cast<size_t>(b.exponent - a.exponent);
-    if (shift < Bits)
-      a.shift_right(static_cast<int>(shift));
-    else
-      a = DyadicFloat<Bits>();
-  }
+  if (a.exponent > b.exponent)
+    b.shift_right(a.exponent - b.exponent);
+  else if (b.exponent > a.exponent)
+    a.shift_right(b.exponent - a.exponent);
 
   DyadicFloat<Bits> result;
 

>From 46636710aec6bdabc94185cfa7f927241291553f Mon Sep 17 00:00:00 2001
From: Tue Ly <lntue.h at gmail.com>
Date: Tue, 9 Jul 2024 19:20:13 +0000
Subject: [PATCH 3/3] Fix signed/unsigned mismatch.

---
 libc/src/__support/FPUtil/dyadic_float.h | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/libc/src/__support/FPUtil/dyadic_float.h b/libc/src/__support/FPUtil/dyadic_float.h
index 3c80e25350ba8..32267bb68e1cc 100644
--- a/libc/src/__support/FPUtil/dyadic_float.h
+++ b/libc/src/__support/FPUtil/dyadic_float.h
@@ -67,10 +67,10 @@ template <size_t Bits> struct DyadicFloat {
   }
 
   // Used for aligning exponents.  Output might not be normalized.
-  LIBC_INLINE constexpr DyadicFloat &shift_left(int shift_length) {
+  LIBC_INLINE constexpr DyadicFloat &shift_left(unsigned shift_length) {
     if (shift_length < Bits) {
-      exponent -= shift_length;
-      mantissa <<= static_cast<size_t>(shift_length);
+      exponent -= static_cast<int>(shift_length);
+      mantissa <<= shift_length;
     } else {
       exponent = 0;
       mantissa = MantissaType(0);
@@ -79,10 +79,10 @@ template <size_t Bits> struct DyadicFloat {
   }
 
   // Used for aligning exponents.  Output might not be normalized.
-  LIBC_INLINE constexpr DyadicFloat &shift_right(int shift_length) {
+  LIBC_INLINE constexpr DyadicFloat &shift_right(unsigned shift_length) {
     if (shift_length < Bits) {
-      exponent += shift_length;
-      mantissa >>= static_cast<size_t>(shift_length);
+      exponent += static_cast<int>(shift_length);
+      mantissa >>= shift_length;
     } else {
       exponent = 0;
       mantissa = MantissaType(0);
@@ -271,9 +271,9 @@ LIBC_INLINE constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
 
   // Align exponents
   if (a.exponent > b.exponent)
-    b.shift_right(a.exponent - b.exponent);
+    b.shift_right(static_cast<unsigned>(a.exponent - b.exponent));
   else if (b.exponent > a.exponent)
-    a.shift_right(b.exponent - a.exponent);
+    a.shift_right(static_cast<unsigned>(b.exponent - a.exponent));
 
   DyadicFloat<Bits> result;
 



More information about the libc-commits mailing list