[libc-commits] [libc] [llvm] [libc][math] Impl bfloat16 lgamma function. (PR #199312)
via libc-commits
libc-commits at lists.llvm.org
Wed Jun 17 09:32:54 PDT 2026
================
@@ -0,0 +1,196 @@
+//===-- Implementation of lgammabf16 ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC___SUPPORT_MATH_LGAMMABF16_H
+#define LLVM_LIBC_SRC___SUPPORT_MATH_LGAMMABF16_H
+
+#include "hdr/errno_macros.h"
+#include "hdr/fenv_macros.h"
+#include "src/__support/FPUtil/FEnvImpl.h"
+#include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/FPUtil/PolyEval.h"
+#include "src/__support/FPUtil/bfloat16.h"
+#include "src/__support/FPUtil/cast.h"
+#include "src/__support/macros/config.h"
+#include "src/__support/macros/optimization.h"
+#include "src/__support/math/log.h"
+
+namespace LIBC_NAMESPACE_DECL {
+namespace math {
+
+// lgamma_positive_d: compute lgamma(x) for x > 0, returning double.
+//
+// Takes double so callers can pass (1.0 + ax) without float precision loss.
+// The caller must guarantee x > 0 (and not NaN); no other inputs are handled.
+//
+// Evaluation strategy:
+//   * x == 1 or x == 2 : the result is exactly 0.
+//   * x >= 8           : two-term Stirling series.
+//   * 4 <= x < 8       : degree-4 polynomial fitted on [n, n+1), n = 4..7.
+//   * 0 < x < 4        : the recurrence lgamma(x) = lgamma(x+1) - ln(x) is
+//                        applied until the shifted argument lands in [4, 5),
+//                        then the [4,5) polynomial is evaluated.
+//
+// Polynomials for intervals below [4,5) are deliberately not used. A fit on
+// [2,3) has max_rel_err=9.98e-6 which, near its edges (t near +0.5 or -0.5),
+// causes ~6e-6 absolute error. After subtracting ln(x), the result near x=1
+// or x=2 can be as small as ~0.002, giving ~0.45 ULP error -- which fails the
+// directed-rounding tolerance test. Polynomials for [4,5) and above have
+// max_rel_err <= 6.61e-7, keeping final error well under 0.1 ULP.
+LIBC_INLINE double lgamma_positive_d(double x) {
+  // Lowest and highest interval start n for which a polynomial is stored.
+  constexpr int FIRST_INTERVAL = 4;
+  constexpr int LAST_INTERVAL = 7;
+
+  // Coefficients for lgamma on [n, n+1), centered at n+0.5.
+  // Each row: {c0, c1, c2, c3, c4} for fputil::polyeval(t, c[0]..c[4])
+  // where t = x - (n + 0.5), n = FIRST_INTERVAL..LAST_INTERVAL.
+  // Row index is n - FIRST_INTERVAL. Only intervals that the code below can
+  // actually reach are stored.
+  // Maximum relative errors per interval.
+  static constexpr float LGAMMA_POLY[LAST_INTERVAL - FIRST_INTERVAL + 1][5] = {
+      // [4,5), center=4.5, max_rel_err=6.61e-07
+      {0x1.3a140ap+1f, 0x1.638d3ep+0f, 0x1.fd62a0p-4f, -0x1.51efeep-7f,
+       0x1.4e023cp-10f},
+      // [5,6), center=5.5, max_rel_err=2.72e-07
+      {0x1.fa99a6p+1f, 0x1.9c70aep+0f, 0x1.984080p-4f, -0x1.b21838p-8f,
+       0x1.58a5b2p-11f},
+      // [6,7), center=6.5, max_rel_err=1.32e-07
+      {0x1.6a676ap+2f, 0x1.cafc46p+0f, 0x1.548cdap-4f, -0x1.2e0b0cp-8f,
+       0x1.909642p-12f},
+      // [7,8), center=7.5, max_rel_err=7.10e-08
+      {0x1.e23306p+2f, 0x1.f25eb8p+0f, 0x1.2413bep-4f, -0x1.bc5850p-9f,
+       0x1.f9d4bep-13f},
+  };
+
+  // lgamma(1) = lgamma(2) = 0 exactly.
+  if (LIBC_UNLIKELY(x == 1.0 || x == 2.0))
+    return 0.0;
+
+  if (x >= 8.0) {
+    // Stirling series:
+    //   (x - 1/2) * ln(x) - x + ln(2*pi)/2 + 1/(12x) - 1/(360x^3)
+    constexpr double HALF_LN_2PI = 0x1.d67f1c864beb4p-1; // 0.5*ln(2*pi)
+    double lx = math::log(x);
+    double x2 = x * x;
+    double result = (x - 0.5) * lx - x + HALF_LN_2PI;
+    result += 1.0 / (12.0 * x) - 1.0 / (360.0 * x * x2);
+    return result;
+  }
+
+  // For x in (0, 4): apply the recurrence until the argument reaches [4, 5),
+  // accumulating the ln() terms that must be subtracted at the end.
+  // For x in [4, 8) the loop body never runs.
+  double log_product = 0.0;
+  double xs = x;
+  while (xs < 4.0) {
+    log_product += math::log(xs);
+    xs += 1.0;
+  }
+
+  // xs is now in [4, 8); cast to float for polynomial evaluation.
+  float xf = static_cast<float>(xs);
+  int n = static_cast<int>(xf);
+  // The double -> float conversion may round a value just below 8 up to 8.0f;
+  // keep such inputs on the last interval (t becomes +0.5).
+  if (n >= LAST_INTERVAL)
+    n = LAST_INTERVAL;
+  float t = xf - (static_cast<float>(n) + 0.5f);
+  const float *c = LGAMMA_POLY[n - FIRST_INTERVAL];
+  double lgamma_xs =
+      static_cast<double>(fputil::polyeval(t, c[0], c[1], c[2], c[3], c[4]));
+  return lgamma_xs - log_product;
+}
+
+LIBC_INLINE bfloat16 lgammabf16(bfloat16 x) {
+ using FPBits = fputil::FPBits<bfloat16>;
+ FPBits x_bits(x);
+
+ // Handle NaN
+ if (LIBC_UNLIKELY(x_bits.is_nan())) {
+ if (x_bits.is_signaling_nan()) {
+ fputil::raise_except_if_required(FE_INVALID);
+ return FPBits::quiet_nan().get_val();
+ }
+ return x;
+ }
+
+ uint16_t x_u = x_bits.uintval();
+ uint16_t x_abs = x_u & 0x7fffU;
+
+ // +Inf or -Inf -> +Inf
+ if (LIBC_UNLIKELY(x_abs == 0x7f80U))
+ return FPBits::inf(Sign::POS).get_val();
+
+ // +-0 -> +Inf (pole error)
+ if (LIBC_UNLIKELY(x_abs == 0U)) {
+ fputil::set_errno_if_required(ERANGE);
+ fputil::raise_except_if_required(FE_DIVBYZERO);
+ return FPBits::inf(Sign::POS).get_val();
+ }
+
+ float xf = static_cast<float>(x);
+
+ // Negative integers -> +Inf (pole error)
+ if (LIBC_UNLIKELY(x_bits.is_neg())) {
+ int biased_exp = x_abs >> FPBits::FRACTION_LEN;
+ if (biased_exp >= FPBits::EXP_BIAS) {
+ int e = biased_exp - FPBits::EXP_BIAS;
+ if (e >= FPBits::FRACTION_LEN ||
+ (x_bits.get_mantissa() &
+ static_cast<uint16_t>((1U << (FPBits::FRACTION_LEN - e)) - 1U)) ==
+ 0U) {
+ fputil::set_errno_if_required(ERANGE);
+ fputil::raise_except_if_required(FE_DIVBYZERO);
+ return FPBits::inf(Sign::POS).get_val();
+ }
+ }
+
+ // Negative non-integer: reflection formula
+ // lgamma(x) = ln(pi) - ln|sin(pi*x)| - lgamma(1-x)
+ constexpr double LN_PI_D = 0x1.250d048e7a1bdp+0;
+ float ax = -xf;
+ float frac = ax - static_cast<float>(static_cast<int>(ax));
+
+ // Map frac to [0, 0.5] to guarantee sin is positive
+ if (frac > 0.5f)
+ frac = 1.0f - frac;
+
+ // sin(pi*frac) in double via Taylor series for sin(x)/x:
+ // 1 - x^2/6 + x^4/120 - x^6/5040 + x^8/362880
+ constexpr double PI_D = 0x1.921fb54442d18p+1;
+ double frac_d = static_cast<double>(frac);
+ double x_pi_d = PI_D * frac_d;
+ double x_pi2_d = x_pi_d * x_pi_d;
+ constexpr double DC1 = -0x1.5555555555555p-3; // -1/6
+ constexpr double DC2 = 0x1.1111111111111p-7; // 1/120
+ constexpr double DC3 = -0x1.a01a01a01a01ap-13; // -1/5040
+ constexpr double DC4 = 0x1.71de3a556c734p-19; // 1/362880
+ double sin_pi_frac_d =
+ x_pi_d *
+ (1.0 +
+ x_pi2_d * (DC1 + x_pi2_d * (DC2 + x_pi2_d * (DC3 + x_pi2_d * DC4))));
----------------
lntue wrote:
Use Estrin's scheme and `fputil::multiply_add` to take advantage of fma instructions if available.
https://github.com/llvm/llvm-project/pull/199312
More information about the libc-commits
mailing list