[libc] [llvm] [libc][math] Implement C23 half precision pow function (PR #159906)

Muhammad Bassiouni via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 2 20:53:50 PST 2026


================
@@ -7,399 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/math/powf16.h"
-#include "hdr/errno_macros.h"
-#include "hdr/fenv_macros.h"
-#include "src/__support/CPP/bit.h"
-#include "src/__support/FPUtil/FEnvImpl.h"
-#include "src/__support/FPUtil/FPBits.h"
-#include "src/__support/FPUtil/PolyEval.h"
-#include "src/__support/FPUtil/cast.h"
-#include "src/__support/FPUtil/multiply_add.h"
-#include "src/__support/FPUtil/nearest_integer.h"
-#include "src/__support/FPUtil/sqrt.h"
-#include "src/__support/common.h"
-#include "src/__support/macros/config.h"
-#include "src/__support/macros/optimization.h"
-#include "src/__support/macros/properties/types.h"
-#include "src/__support/math/common_constants.h"
-#include "src/__support/math/exp10f_utils.h"
+#include "src/__support/math/powf16.h"
 
 namespace LIBC_NAMESPACE_DECL {
 
-using namespace common_constants_internal;
-
-namespace {
-
-LIBC_INLINE static double exp2_range_reduced(double x) {
-  // k = round(x * 32)  => (hi + mid) * 2^5
-  double kf = fputil::nearest_integer(x * 32.0);
-  int k = static_cast<int>(kf);
-  // dx = lo = x - (hi + mid) = x - k * 2^(-5)
-  double dx = fputil::multiply_add(-0x1.0p-5, kf, x); // -2^-5 * k + x
-
-  // hi = k >> MID_BITS
-  // exp_hi = hi shifted into double exponent field
-  int64_t hi = static_cast<int64_t>(k >> ExpBase::MID_BITS);
-  int64_t exp_hi = static_cast<int64_t>(
-      static_cast<uint64_t>(hi) << fputil::FPBits<double>::FRACTION_LEN);
-
-  // mh_bits = bits for 2^hi * 2^mid  (lookup contains base bits for 2^mid)
-  int tab_index = k & ExpBase::MID_MASK; // mid index in [0, 31]
-  int64_t mh_bits = ExpBase::EXP_2_MID[tab_index] + exp_hi;
-
-  // mh = 2^(hi + mid)
-  double mh = fputil::FPBits<double>(static_cast<uint64_t>(mh_bits)).get_val();
-
-  // Degree-5 polynomial approximating (2^x - 1)/x generating by Sollya with:
-  // > P = fpminimax((2^x - 1)/x, 5, [|D...|], [-1/32. 1/32]);
-  constexpr double COEFFS[5] = {0x1.62e42fefa39efp-1, 0x1.ebfbdff8131c4p-3,
-                                0x1.c6b08d7061695p-5, 0x1.3b2b1bee74b2ap-7,
-                                0x1.5d88091198529p-10};
-
-  double dx_sq = dx * dx;
-  double c1 = fputil::multiply_add(dx, COEFFS[0], 1.0); // 1 + ln2*dx
-  double c2 =
-      fputil::multiply_add(dx, COEFFS[2], COEFFS[1]); // COEFF1 + COEFF2*dx
-  double c3 =
-      fputil::multiply_add(dx, COEFFS[4], COEFFS[3]); // COEFF3 + COEFF4*dx
-  double p = fputil::multiply_add(dx_sq, c3, c2);     // c2 + c3*dx^2
-
-  // 2^x = 2^(hi+mid) * 2^dx
-  //     ≈ mh * (1 + dx * P(dx))
-  //     = mh + (mh * dx) * P(dx)
-  double result = fputil::multiply_add(p, dx_sq * mh, c1 * mh);
-
-  return result;
-}
-
-LIBC_INLINE bool is_odd_integer(float16 x) {
-  using FPBits = fputil::FPBits<float16>;
-  FPBits xbits(x);
-  uint16_t x_u = xbits.uintval();
-  unsigned x_e = static_cast<unsigned>(xbits.get_biased_exponent());
-  unsigned lsb = static_cast<unsigned>(
-      cpp::countr_zero(static_cast<uint32_t>(x_u | FPBits::EXP_MASK)));
-  constexpr unsigned UNIT_EXPONENT =
-      static_cast<unsigned>(FPBits::EXP_BIAS + FPBits::FRACTION_LEN);
-  return (x_e + lsb == UNIT_EXPONENT);
-}
-
-LIBC_INLINE bool is_integer(float16 x) {
-  using FPBits = fputil::FPBits<float16>;
-  FPBits xbits(x);
-  uint16_t x_u = xbits.uintval();
-  unsigned x_e = static_cast<unsigned>(xbits.get_biased_exponent());
-  unsigned lsb = static_cast<unsigned>(
-      cpp::countr_zero(static_cast<uint32_t>(x_u | FPBits::EXP_MASK)));
-  constexpr unsigned UNIT_EXPONENT =
-      static_cast<unsigned>(FPBits::EXP_BIAS + FPBits::FRACTION_LEN);
-  return (x_e + lsb >= UNIT_EXPONENT);
-}
-
-} // namespace
-
-LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
-  using FPBits = fputil::FPBits<float16>;
-
-  FPBits xbits(x), ybits(y);
-  bool x_sign = xbits.is_neg();
-  bool y_sign = ybits.is_neg();
-
-  FPBits x_abs = xbits.abs();
-  FPBits y_abs = ybits.abs();
-
-  uint16_t x_u = xbits.uintval();
-  uint16_t x_a = x_abs.uintval();
-  uint16_t y_a = y_abs.uintval();
-  uint16_t y_u = ybits.uintval();
-  bool result_sign = false;
-
-  ///////// BEGIN - Check exceptional cases ////////////////////////////////////
-  // If x or y is signaling NaN
-  if (xbits.is_signaling_nan() || ybits.is_signaling_nan()) {
-    fputil::raise_except_if_required(FE_INVALID);
-    return FPBits::quiet_nan().get_val();
-  }
-
-  if (LIBC_UNLIKELY(
-          ybits.is_zero() || x_u == FPBits::one().uintval() || xbits.is_nan() ||
-          ybits.is_nan() || x_u == FPBits::one().uintval() ||
-          x_u == FPBits::zero().uintval() || x_u >= FPBits::inf().uintval() ||
-          y_u >= FPBits::inf().uintval() ||
-          x_u < FPBits::min_normal().uintval() || y_a == 0x3400U || // 0.25
-          y_a == 0x3800U ||                                         // 0.5
-          y_a == 0x3A00U ||                                         // 0.75
-          y_a == 0x3D00U ||                                         // 1.25
-          y_a == 0x3E00U ||                                         // 1.5
-          y_a == 0x4000U ||                                         // 2.0
-          y_a == 0x4100U ||                                         // 2.5
-          y_a == 0x4300U ||                                         // 3.5
-          is_integer(y))) {
-    // pow(x, 0) = 1
-    if (ybits.is_zero()) {
-      return 1.0f16;
-    }
-
-    // pow(1, Y) = 1
-    if (x_u == FPBits::one().uintval()) {
-      return 1.0f16;
-    }
-    // 4. Handle remaining NaNs
-    // pow(NaN, y) = NaN (for y != 0)
-    if (xbits.is_nan()) {
-      return x;
-    }
-    // pow(x, NaN) = NaN (for x != 1)
-    if (ybits.is_nan()) {
-      return y;
-    }
-    switch (y_a) {
-    case 0x3400U: // y = ±0.25 (1/4)
-    case 0x3800U: // y = ±0.5 (1/2)
-    case 0x3A00U: // y = ±0.75 (3/4)
-    case 0x3D00U: // y = ±1.25 (5/4)
-    case 0x3E00U: // y = ±1.5 (3/2)
-    case 0x4100U: // y = ±2.5 (5/2)
-    case 0x4300U: // y = ±3.5 (7/2)
-    {
-      if (xbits.is_zero()) {
-        if (y_sign) {
-          // pow(±0, negative) handled below
-          break;
-        } else {
-          // pow(±0, positive_fractional) = +0
-          return FPBits::zero(Sign::POS).get_val();
-        }
-      }
-
-      if (x_sign && !xbits.is_zero()) {
-        break; // pow(negative, non-integer) = NaN
-      }
-
-      double x_d = static_cast<double>(x);
-      double sqrt_x = fputil::sqrt<double>(x_d);
-      double fourth_root = fputil::sqrt<double>(sqrt_x);
-      double result_d;
-
-      // Compute based on exponent value
-      switch (y_a) {
-      case 0x3400U: // 0.25 = x^(1/4)
-        result_d = fourth_root;
-        break;
-      case 0x3800U: // 0.5 = x^(1/2)
-        result_d = sqrt_x;
-        break;
-      case 0x3A00U: // 0.75 = x^(1/2) * x^(1/4)
-        result_d = sqrt_x * fourth_root;
-        break;
-      case 0x3D00U: // 1.25 = x * x^(1/4)
-        result_d = x_d * fourth_root;
-        break;
-      case 0x3E00U: // 1.5 = x * x^(1/2)
-        result_d = x_d * sqrt_x;
-        break;
-      case 0x4100U: // 2.5 = x^2 * x^(1/2)
-        result_d = x_d * x_d * sqrt_x;
-        break;
-      case 0x4300U: // 3.5 = x^3 * x^(1/2)
-        result_d = x_d * x_d * x_d * sqrt_x;
-        break;
-      }
-
-      result_d = y_sign ? (1.0 / result_d) : result_d;
-      return fputil::cast<float16>(result_d);
-    }
-    case 0x3c00U: // y = +-1.0
-      return fputil::cast<float16>(y_sign ? (1.0 / x) : x);
-
-    case 0x4000U: // y = +-2.0
-      double result_d = static_cast<double>(x) * static_cast<double>(x);
-      return fputil::cast<float16>(y_sign ? (1.0 / (result_d)) : (result_d));
-    }
-    // TODO: Speed things up with pow(2, y) = exp2(y) and pow(10, y) = exp10(y).
-    //
-    // pow(-1, y) for integer y
-    if (x_u == FPBits::one(Sign::NEG).uintval()) {
-      if (is_integer(y)) {
-        if (is_odd_integer(y)) {
-          return -1.0f16;
-        } else {
-          return 1.0f16;
-        }
-      }
-      // pow(-1, non-integer) = NaN
-      fputil::set_errno_if_required(EDOM);
-      fputil::raise_except_if_required(FE_INVALID);
-      return FPBits::quiet_nan().get_val();
-    }
-
-    // pow(±0, y) cases
-    if (xbits.is_zero()) {
-      if (y_sign) {
-        // pow(+-0, negative) = +-inf and raise FE_DIVBYZERO
-        fputil::raise_except_if_required(FE_DIVBYZERO);
-        bool result_neg = x_sign && ybits.is_finite() && is_odd_integer(y);
-        return FPBits::inf(result_neg ? Sign::NEG : Sign::POS).get_val();
-      } else {
-        // pow(+-0, positive) = +-0
-        bool out_is_neg = x_sign && is_odd_integer(y);
-        return out_is_neg ? FPBits::zero(Sign::NEG).get_val()
-                          : FPBits::zero(Sign::POS).get_val();
-      }
-    }
-
-    if (xbits.is_inf()) {
-      bool out_is_neg = x_sign && ybits.is_finite() && is_odd_integer(y);
-      if (y_sign) // pow(+-inf, negative) = +-0
-        return out_is_neg ? FPBits::zero(Sign::NEG).get_val()
-                          : FPBits::zero(Sign::POS).get_val();
-      // pow(+-inf, positive) = +-inf
-      return FPBits::inf(out_is_neg ? Sign::NEG : Sign::POS).get_val();
-    }
-
-    // y = +-inf cases
-    if (ybits.is_inf()) {
-      // pow(1, inf) handled above.
-      bool x_abs_less_than_one = x_a < FPBits::one().uintval();
-      if ((x_abs_less_than_one && !y_sign) ||
-          (!x_abs_less_than_one && y_sign)) {
-        // |x| < 1 and y = +inf => 0.0
-        // |x| > 1 and y = -inf => 0.0
-        return 0.0f16;
-      } else {
-        // |x| > 1 and y = +inf => +inf
-        // |x| < 1 and y = -inf => +inf
-        return FPBits::inf(Sign::POS).get_val();
-      }
-    }
-
-    // pow( negative, non-integer ) = NaN
-    if (x_sign && !is_integer(y)) {
-      fputil::set_errno_if_required(EDOM);
-      fputil::raise_except_if_required(FE_INVALID);
-      return FPBits::quiet_nan().get_val();
-    }
-
-    bool result_sign = false;
-    if (x_sign && is_integer(y)) {
-      result_sign = is_odd_integer(y);
-    }
-
-    if (is_integer(y)) {
-      double base = x_abs.get_val();
-      double res = 1.0;
-      int yi = static_cast<int>(y_abs.get_val());
-
-      // Fast exponentiation by squaring
-      while (yi > 0) {
-        if (yi & 1)
-          res *= base;
-        base *= base;
-        yi = yi >> 1;
-      }
-
-      if (y_sign) {
-        res = 1.0 / res;
-      }
-
-      if (result_sign) {
-        res = -res;
-      }
-
-      if (FPBits(fputil::cast<float16>(res)).is_inf()) {
-        fputil::raise_except_if_required(FE_OVERFLOW);
-        res = result_sign ? -0x1.0p20 : 0x1.0p20;
-      }
-
-      float16 final_res = fputil::cast<float16>(res);
-      return final_res;
-    }
-  }
-
-  ///////// END - Check exceptional cases //////////////////////////////////////
-
-  // Core computation: x^y = 2^( y * log2(x) )
-  // We compute log2(x) = log(x) / log(2) using a polynomial approximation.
-
-  // The exponent part (m) is added later to get the final log(x).
-  FPBits x_bits(x);
-  uint16_t x_u_log = x_bits.uintval();
-
-  // Extract exponent field of x.
-  int m = x_bits.get_exponent();
-
-  // When x is subnormal, normalize it by adjusting m.
-  if ((x_u_log & FPBits::EXP_MASK) == 0U) {
-    unsigned leading_zeros =
-        cpp::countl_zero(static_cast<uint32_t>(x_u_log)) - (32 - 16);
-
-    constexpr unsigned SUBNORMAL_SHIFT_CORRECTION = 5;
-    unsigned shift = leading_zeros - SUBNORMAL_SHIFT_CORRECTION;
-
-    x_bits.set_mantissa(static_cast<uint16_t>(x_u_log << shift));
-
-    m = 1 - FPBits::EXP_BIAS - static_cast<int>(shift);
-  }
-
-  // Extract the mantissa and index into small lookup tables.
-  uint16_t mant = x_bits.get_mantissa();
-  // Use the highest 7 fractional bits of the mantissa as the index f.
-  int f = mant >> (FPBits::FRACTION_LEN - 7);
-
-  // Reconstruct the mantissa value m_x so it's in the range [1.0, 2.0).
-  x_bits.set_biased_exponent(FPBits::EXP_BIAS);
-  double mant_d = x_bits.get_val();
-  // Degree-5 polynomial approximation
-  // of log2 generated by Sollya with:
-  // > P = fpminimax(log2(1 + x)/x, 4, [|1, D...|], [-2^-8, 2^-7]);
-  constexpr double COEFFS[5] = {0x1.71547652b8133p0, -0x1.71547652d1e33p-1,
-                                0x1.ec70a098473dep-2, -0x1.7154c5ccdf121p-2,
-                                0x1.2514fd90a130ap-2};
-
-#ifdef LIBC_TARGET_CPU_HAS_FMA_DOUBLE
-  double v = fputil::multiply_add<double>(mant_d, RD[f], -1.0);
-#else
-  double c = fputil::FPBits<double>(fputil::FPBits<double>(mant_d).uintval() &
-                                    0x3fff'e000'0000'0000)
-                 .get_val();
-  double v = fputil::multiply_add(RD[f], mant_d - c, CD[f]);
-#endif // LIBC_TARGET_CPU_HAS_FMA_DOUBLE
-  double extra_factor = static_cast<double>(m) + LOG2_R[f];
-  double vsq = v * v;
-  double c0 = fputil::multiply_add(v, COEFFS[0], 0.0);
-  double c1 = fputil::multiply_add(v, COEFFS[2], COEFFS[1]);
-  double c2 = fputil::multiply_add(v, COEFFS[4], COEFFS[3]);
-
-  double log2_x = fputil::polyeval(vsq, c0, c1, c2);
-
-  double y_d = fputil::cast<double>(y);
-  double z = fputil::multiply_add(y_d, log2_x, y_d * extra_factor);
-
-  // Check for underflow
-  // Float16 min normal is 2^-14, smallest subnormal is 2^-24
-  if (LIBC_UNLIKELY(z < -25.0)) {
-    fputil::raise_except_if_required(FE_UNDERFLOW);
-    return result_sign ? FPBits::zero(Sign::NEG).get_val()
-                       : FPBits::zero(Sign::POS).get_val();
-  }
-
-  // Check for overflow
-  // Float16 max is ~2^16
-  double result_d;
-  if (LIBC_UNLIKELY(z > 16.0)) {
-    fputil::raise_except_if_required(FE_OVERFLOW);
-    result_d = result_sign ? -0x1.0p20 : 0x1.0p20;
-  } else {
-    result_d = exp2_range_reduced(z);
-  }
-
-  if (result_sign) {
-
-    result_d = -result_d;
-  }
-
-  float16 result = fputil::cast<float16>((result_d));
-  return result;
-}
+// Thin wrapper over math::powf16 (src/__support/math/powf16.h).
+LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
+  return math::powf16(x, y);
+}
----------------
bassiounix wrote:

Make sure to use `clang-format`.

https://github.com/llvm/llvm-project/pull/159906


More information about the llvm-commits mailing list