[libc-commits] [libc] [libc][math][c23] Improve rsqrtf16() function (PR #160639)

Anton Shepelev via libc-commits libc-commits at lists.llvm.org
Mon Nov 17 22:00:44 PST 2025


https://github.com/amemov updated https://github.com/llvm/llvm-project/pull/160639

>From 595bb6d564f593ad6e3487add76a3e9cb96dd628 Mon Sep 17 00:00:00 2001
From: Anton Shepelev <shepelev777 at gmail.com>
Date: Wed, 24 Sep 2025 20:59:06 -0700
Subject: [PATCH 1/2] Draft a math approximation for targets that have only
 integer hardware; fix a typo in the +inf case, which should return +0
 according to F.10.4.9

---
 libc/src/__support/math/rsqrtf16.h | 109 +++++++++++++++++++++++++++--
 1 file changed, 105 insertions(+), 4 deletions(-)

diff --git a/libc/src/__support/math/rsqrtf16.h b/libc/src/__support/math/rsqrtf16.h
index 30ab58f8a5798..e493d30d463de 100644
--- a/libc/src/__support/math/rsqrtf16.h
+++ b/libc/src/__support/math/rsqrtf16.h
@@ -58,12 +58,11 @@ LIBC_INLINE static constexpr float16 rsqrtf16(float16 x) {
       return FPBits::quiet_nan().get_val();
     }
 
-    // x = +inf => rsqrt(x) = 0
-    return FPBits::zero().get_val();
+    // x = +inf => rsqrt(x) = +0
+    return FPBits::zero(xbits.sign()).get_val();
   }
 
-  // TODO: add integer based implementation when LIBC_TARGET_CPU_HAS_FPU_FLOAT
-  // is not defined
+#ifdef LIBC_TARGET_CPU_HAS_FPU_FLOAT
   float result = 1.0f / fputil::sqrt<float>(fputil::cast<float>(x));
 
   // Targeted post-corrections to ensure correct rounding in half for specific
@@ -76,6 +75,108 @@ LIBC_INLINE static constexpr float16 rsqrtf16(float16 x) {
   }
 
   return fputil::cast<float16>(result);
+
+#else
+  // Range reduction:
+  // x can be expressed as m*2^e, where e - int exponent and m - mantissa
+  // rsqrtf16(x) = rsqrtf16(m*2^e)
+  // rsqrtf16(m*2^e) = 1/sqrt(m) * 1/sqrt(2^e) = 1/sqrt(m) * 1/2^(e/2)
+  // 1/sqrt(m) * 1/2^(e/2) = 1/sqrt(m) * 2^(-e/2)
+
+  // Compute reduction directly from half bits to avoid frexp/ldexp overhead.
+  int exponent = 0;
+  int signifcand = 0; // same as mantissa, but int
+  uint16_t eh = static_cast<uint16_t>((x_abs >> 10) & 0x1F);
+  uint16_t frac = static_cast<uint16_t>(x_abs & 0x3FF);
+
+  int result;
+  if (eh != 0) {
+    // ((2^-1 + frac/2^11) * 2) * 2^(eh-15)
+
+    // Normal: x = (1 + frac/2^10) * 2^(eh-15) = ((0.5 + frac/2^11) * 2) *
+    // 2^(eh-15)
+    // => mantissa in [0.5,1): m = 0.5 + frac/2^11, exponent = (eh - 15) + 1 =
+    // eh - 14
+    exponent = static_cast<int>(eh) - 14;
+    mantissa = 0.5f + static_cast<float>(frac) * 0x1.0p-11f;
+  } else {
+    // Subnormal: x = (frac/2^10) * 2^(1-15) = frac * 2^-24.
+    // Normalize frac so that bit 9 becomes 1; then mantissa m = (frac <<
+    // t)/2^10 ∈ [0.5,1) and exponent E = -14 - t so that x = m * 2^E.
+    if (LIBC_UNLIKELY(frac == 0)) {
+      // Should have been handled by zero check above, but keep safe.
+      return FPBits::inf(Sign::POS).get_val();
+    }
+    int shifts = 0;
+    while ((frac & 0x200u) == 0u) { // bring into [0x200, 0x3FF]
+      frac <<= 1;
+      ++shifts;
+    }
+    exponent = -14 - shifts;
+    mantissa = static_cast<float>(frac) * 0x1.0p-10f;
+  }
+
+  float result = 0.0f;
+  int exp_floored = -(exponent >> 1);
+
+  if (mantissa == 0.5f) {
+    // When mantissa is 0.5f, x was a power of 2 (or subnormal that normalizes
+    // this way). 1/sqrt(0.5f) = sqrt(2.0f).
+    // If exponent is odd (exponent = 2k + 1):
+    //   rsqrt(x) = (1/sqrt(0.5)) * 2^(-(2k+1)/2) = sqrt(2) * 2^(-k-0.5)
+    //            = sqrt(2) * 2^(-k) * (1/sqrt(2)) = 2^(-k)
+    //   exp_floored = -((2k+1)>>1) = -(k) = -k
+    //   So result = ldexp(1.0f, exp_floored)
+    // If exponent is even (exponent = 2k):
+    //   rsqrt(x) = (1/sqrt(0.5)) * 2^(-2k/2) = sqrt(2) * 2^(-k)
+    //   exp_floored = -((2k)>>1) = -(k) = -k
+    //   So result = ldexp(sqrt(2.0f), exp_floored)
+    if (exponent & 1) {
+      result = fputil::ldexp(1.0f, exp_floored);
+    } else {
+      constexpr float SQRT_2_F = 0x1.6a09e6p0f; // sqrt(2.0f)
+      result = fputil::ldexp(SQRT_2_F, exp_floored);
+    }
+  } else {
+    // 4 Degree minimax polynomial (single-precision coefficients) generated
+    // with Sollya: P = fpminimax(1/sqrt(x), 4,
+    // [|single,single,single,single,single|], [0.5;1])
+    float y = fputil::polyeval(mantissa,
+                               0x1.771256p1f,  // c0
+                               -0x1.5e7c4ap2f, // c1
+                               0x1.b3851cp2f,  // c2
+                               -0x1.1a27ep2f,  // c3
+                               0x1.265c66p0f); // c4
+
+    // Newton-Raphson iteration in float (use multiply_add to leverage FMA when
+    // available):
+    float y2 = y * y;
+    float factor = fputil::multiply_add(-0.5f * mantissa, y2, 1.5f);
+    y = y * factor;
+
+    result = fputil::ldexp(y, exp_floored);
+    if (exponent & 1) {
+      constexpr float ONE_OVER_SQRT2 = 0x1.6a09e6p-1f; // 1/sqrt(2)
+      result *= ONE_OVER_SQRT2;
+    }
+
+    // Targeted post-correction: for the specific half-precision mantissa
+    // pattern M == 0x011F we observe a consistent -1 ULP bias across exponents.
+    // Apply a tiny upward nudge to cross the rounding boundary in all modes.
+    const uint16_t half_mantissa = static_cast<uint16_t>(x_abs & 0x3ff);
+    if (half_mantissa == 0x011F) {
+      // Nudge up to fix consistent -1 ULP at that mantissa boundary
+      result = fputil::multiply_add(result, 0x1.0p-21f,
+                                    result); // result *= (1 + 2^-21)
+    } else if (half_mantissa == 0x0313) {
+      // Nudge down to fix +1 ULP under upward rounding at this mantissa
+      // boundary
+      result = fputil::multiply_add(result, -0x1.0p-21f,
+                                    result); // result *= (1 - 2^-21)
+    }
+  }
+  return fputil::cast<float16>(result);
+#endif
 }
 
 } // namespace math

>From 06f0af758528fdba3df4867727a3798139795ef4 Mon Sep 17 00:00:00 2001
From: amemov <shepelev777 at gmail.com>
Date: Mon, 17 Nov 2025 22:00:28 -0800
Subject: [PATCH 2/2] Added approximation from previous PR

---
 libc/src/__support/math/rsqrtf16.h | 45 +++++-------------------------
 1 file changed, 7 insertions(+), 38 deletions(-)

diff --git a/libc/src/__support/math/rsqrtf16.h b/libc/src/__support/math/rsqrtf16.h
index e493d30d463de..f8d9c78a91508 100644
--- a/libc/src/__support/math/rsqrtf16.h
+++ b/libc/src/__support/math/rsqrtf16.h
@@ -16,6 +16,7 @@
 #include "src/__support/FPUtil/FEnvImpl.h"
 #include "src/__support/FPUtil/FPBits.h"
 #include "src/__support/FPUtil/ManipulationFunctions.h"
+#include "src/__support/FPUtil/PolyEval.h"
 #include "src/__support/FPUtil/cast.h"
 #include "src/__support/FPUtil/multiply_add.h"
 #include "src/__support/FPUtil/sqrt.h"
@@ -77,44 +78,10 @@ LIBC_INLINE static constexpr float16 rsqrtf16(float16 x) {
   return fputil::cast<float16>(result);
 
 #else
-  // Range reduction:
-  // x can be expressed as m*2^e, where e - int exponent and m - mantissa
-  // rsqrtf16(x) = rsqrtf16(m*2^e)
-  // rsqrtf16(m*2^e) = 1/sqrt(m) * 1/sqrt(2^e) = 1/sqrt(m) * 1/2^(e/2)
-  // 1/sqrt(m) * 1/2^(e/2) = 1/sqrt(m) * 2^(-e/2)
+  float xf = fputil::cast<float>(x);
 
-  // Compute reduction directly from half bits to avoid frexp/ldexp overhead.
   int exponent = 0;
-  int signifcand = 0; // same as mantissa, but int
-  uint16_t eh = static_cast<uint16_t>((x_abs >> 10) & 0x1F);
-  uint16_t frac = static_cast<uint16_t>(x_abs & 0x3FF);
-
-  int result;
-  if (eh != 0) {
-    // ((2^-1 + frac/2^11) * 2) * 2^(eh-15)
-
-    // Normal: x = (1 + frac/2^10) * 2^(eh-15) = ((0.5 + frac/2^11) * 2) *
-    // 2^(eh-15)
-    // => mantissa in [0.5,1): m = 0.5 + frac/2^11, exponent = (eh - 15) + 1 =
-    // eh - 14
-    exponent = static_cast<int>(eh) - 14;
-    mantissa = 0.5f + static_cast<float>(frac) * 0x1.0p-11f;
-  } else {
-    // Subnormal: x = (frac/2^10) * 2^(1-15) = frac * 2^-24.
-    // Normalize frac so that bit 9 becomes 1; then mantissa m = (frac <<
-    // t)/2^10 ∈ [0.5,1) and exponent E = -14 - t so that x = m * 2^E.
-    if (LIBC_UNLIKELY(frac == 0)) {
-      // Should have been handled by zero check above, but keep safe.
-      return FPBits::inf(Sign::POS).get_val();
-    }
-    int shifts = 0;
-    while ((frac & 0x200u) == 0u) { // bring into [0x200, 0x3FF]
-      frac <<= 1;
-      ++shifts;
-    }
-    exponent = -14 - shifts;
-    mantissa = static_cast<float>(frac) * 0x1.0p-10f;
-  }
+  float mantissa = fputil::frexp(xf, exponent);
 
   float result = 0.0f;
   int exp_floored = -(exponent >> 1);
@@ -139,8 +106,9 @@ LIBC_INLINE static constexpr float16 rsqrtf16(float16 x) {
     }
   } else {
     // 4 Degree minimax polynomial (single-precision coefficients) generated
-    // with Sollya: P = fpminimax(1/sqrt(x), 4,
-    // [|single,single,single,single,single|], [0.5;1])
+    // with Sollya:
+    //   P = fpminimax(1/sqrt(x), 4,
+    //       [|single,single,single,single,single|], [0.5;1])
     float y = fputil::polyeval(mantissa,
                                0x1.771256p1f,  // c0
                                -0x1.5e7c4ap2f, // c1
@@ -175,6 +143,7 @@ LIBC_INLINE static constexpr float16 rsqrtf16(float16 x) {
                                     result); // result *= (1 - 2^-21)
     }
   }
+
   return fputil::cast<float16>(result);
 #endif
 }



More information about the libc-commits mailing list