[libc-commits] [libc] a239343 - [libc][math][c23] Add f16sqrtf C23 math function (#95251)
via libc-commits
libc-commits at lists.llvm.org
Thu Jun 13 09:57:28 PDT 2024
Author: OverMighty
Date: 2024-06-13T12:57:24-04:00
New Revision: a2393435217a0c8832b6b34e13977bb29d722d39
URL: https://github.com/llvm/llvm-project/commit/a2393435217a0c8832b6b34e13977bb29d722d39
DIFF: https://github.com/llvm/llvm-project/commit/a2393435217a0c8832b6b34e13977bb29d722d39.diff
LOG: [libc][math][c23] Add f16sqrtf C23 math function (#95251)
Part of #95250.
Added:
libc/src/math/f16sqrtf.h
libc/src/math/generic/f16sqrtf.cpp
libc/test/src/math/exhaustive/f16sqrtf_test.cpp
libc/test/src/math/smoke/f16sqrtf_test.cpp
Modified:
libc/config/linux/aarch64/entrypoints.txt
libc/config/linux/x86_64/entrypoints.txt
libc/docs/math/index.rst
libc/spec/stdc.td
libc/src/__support/FPUtil/generic/CMakeLists.txt
libc/src/__support/FPUtil/generic/sqrt.h
libc/src/math/CMakeLists.txt
libc/src/math/generic/CMakeLists.txt
libc/src/math/generic/acosf.cpp
libc/src/math/generic/acoshf.cpp
libc/src/math/generic/asinf.cpp
libc/src/math/generic/asinhf.cpp
libc/src/math/generic/hypotf.cpp
libc/src/math/generic/powf.cpp
libc/src/math/generic/sqrt.cpp
libc/src/math/generic/sqrtf.cpp
libc/src/math/generic/sqrtf128.cpp
libc/src/math/generic/sqrtl.cpp
libc/test/src/math/exhaustive/CMakeLists.txt
libc/test/src/math/exhaustive/exhaustive_test.h
libc/test/src/math/smoke/CMakeLists.txt
libc/test/src/math/smoke/SqrtTest.h
libc/utils/MPFRWrapper/CMakeLists.txt
libc/utils/MPFRWrapper/MPFRUtils.cpp
libc/utils/MPFRWrapper/MPFRUtils.h
Removed:
################################################################################
diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index db96a80051a8d..2b2d0985a8992 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -503,6 +503,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
libc.src.math.canonicalizef16
libc.src.math.ceilf16
libc.src.math.copysignf16
+ libc.src.math.f16sqrtf
libc.src.math.fabsf16
libc.src.math.fdimf16
libc.src.math.floorf16
diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index 355eaf33ace6d..2d36ca296c3a4 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -535,6 +535,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
libc.src.math.canonicalizef16
libc.src.math.ceilf16
libc.src.math.copysignf16
+ libc.src.math.f16sqrtf
libc.src.math.fabsf16
libc.src.math.fdimf16
libc.src.math.floorf16
diff --git a/libc/docs/math/index.rst b/libc/docs/math/index.rst
index d556885eda622..790786147c164 100644
--- a/libc/docs/math/index.rst
+++ b/libc/docs/math/index.rst
@@ -280,6 +280,8 @@ Higher Math Functions
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| fma | |check| | |check| | | | | 7.12.13.1 | F.10.10.1 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
+| f16sqrt | |check| | | | N/A | | 7.12.14.6 | F.10.11 |
++-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| fsqrt | N/A | | | N/A | | 7.12.14.6 | F.10.11 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| hypot | |check| | |check| | | | | 7.12.7.4 | F.10.4.4 |
diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td
index b134ec00a7d7a..7c4135032a0b2 100644
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -714,6 +714,8 @@ def StdC : StandardSpec<"stdc"> {
GuardedFunctionSpec<"totalorderf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,
GuardedFunctionSpec<"totalordermagf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,
+
+ GuardedFunctionSpec<"f16sqrtf", RetValSpec<Float16Type>, [ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
]
>;
diff --git a/libc/src/__support/FPUtil/generic/CMakeLists.txt b/libc/src/__support/FPUtil/generic/CMakeLists.txt
index 09eede1570962..595656e3e8d90 100644
--- a/libc/src/__support/FPUtil/generic/CMakeLists.txt
+++ b/libc/src/__support/FPUtil/generic/CMakeLists.txt
@@ -4,6 +4,7 @@ add_header_library(
sqrt.h
sqrt_80_bit_long_double.h
DEPENDS
+ libc.hdr.fenv_macros
libc.src.__support.common
libc.src.__support.CPP.bit
libc.src.__support.CPP.type_traits
diff --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
index 7e7600ba6502a..d6e894fdfe021 100644
--- a/libc/src/__support/FPUtil/generic/sqrt.h
+++ b/libc/src/__support/FPUtil/generic/sqrt.h
@@ -18,6 +18,8 @@
#include "src/__support/common.h"
#include "src/__support/uint128.h"
+#include "hdr/fenv_macros.h"
+
namespace LIBC_NAMESPACE {
namespace fputil {
@@ -64,40 +66,50 @@ LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
// Correctly rounded IEEE 754 SQRT for all rounding modes.
// Shift-and-add algorithm.
-template <typename T>
-LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
-
- if constexpr (internal::SpecialLongDouble<T>::VALUE) {
+template <typename OutType, typename InType>
+LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
+ cpp::is_floating_point_v<InType> &&
+ sizeof(OutType) <= sizeof(InType),
+ OutType>
+sqrt(InType x) {
+ if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
+ internal::SpecialLongDouble<InType>::VALUE) {
// Special 80-bit long double.
return x86::sqrt(x);
} else {
// IEEE floating points formats.
- using FPBits_t = typename fputil::FPBits<T>;
- using StorageType = typename FPBits_t::StorageType;
- constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN;
- constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val();
-
- FPBits_t bits(x);
-
- if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
+ using OutFPBits = typename fputil::FPBits<OutType>;
+ using OutStorageType = typename OutFPBits::StorageType;
+ using InFPBits = typename fputil::FPBits<InType>;
+ using InStorageType = typename InFPBits::StorageType;
+ constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
+ constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();
+ constexpr int EXTRA_FRACTION_LEN =
+ InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
+ constexpr InStorageType EXTRA_FRACTION_MASK =
+ (InStorageType(1) << EXTRA_FRACTION_LEN) - 1;
+
+ InFPBits bits(x);
+
+ if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
// sqrt(+Inf) = +Inf
// sqrt(+0) = +0
// sqrt(-0) = -0
// sqrt(NaN) = NaN
// sqrt(-NaN) = -NaN
- return x;
+ return static_cast<OutType>(x);
} else if (bits.is_neg()) {
// sqrt(-Inf) = NaN
// sqrt(-x) = NaN
return FLT_NAN;
} else {
int x_exp = bits.get_exponent();
- StorageType x_mant = bits.get_mantissa();
+ InStorageType x_mant = bits.get_mantissa();
// Step 1a: Normalize denormal input and append hidden bit to the mantissa
if (bits.is_subnormal()) {
++x_exp; // let x_exp be the correct exponent of ONE bit.
- internal::normalize<T>(x_exp, x_mant);
+ internal::normalize<InType>(x_exp, x_mant);
} else {
x_mant |= ONE;
}
@@ -120,12 +132,13 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
// So the nth digit y_n of the mantissa of sqrt(x) can be found by:
// y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
// 0 otherwise.
- StorageType y = ONE;
- StorageType r = x_mant - ONE;
+ InStorageType y = ONE;
+ InStorageType r = x_mant - ONE;
- for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
+ for (InStorageType current_bit = ONE >> 1; current_bit;
+ current_bit >>= 1) {
r <<= 1;
- StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
+ InStorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
if (r >= tmp) {
r -= tmp;
y += current_bit;
@@ -133,34 +146,91 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
}
// We compute one more iteration in order to round correctly.
- bool lsb = static_cast<bool>(y & 1); // Least significant bit
- bool rb = false; // Round bit
+ bool lsb = (y & (InStorageType(1) << EXTRA_FRACTION_LEN)) !=
+ 0; // Least significant bit
+ bool rb = false; // Round bit
r <<= 2;
- StorageType tmp = (y << 2) + 1;
+ InStorageType tmp = (y << 2) + 1;
if (r >= tmp) {
r -= tmp;
rb = true;
}
+ bool sticky = false;
+
+ if constexpr (EXTRA_FRACTION_LEN > 0) {
+ sticky = rb || (y & EXTRA_FRACTION_MASK) != 0;
+ rb = (y & (InStorageType(1) << (EXTRA_FRACTION_LEN - 1))) != 0;
+ }
+
// Remove hidden bit and append the exponent field.
- x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS);
+ x_exp = ((x_exp >> 1) + OutFPBits::EXP_BIAS);
+
+ OutStorageType y_out = static_cast<OutStorageType>(
+ ((y - ONE) >> EXTRA_FRACTION_LEN) |
+ (static_cast<OutStorageType>(x_exp) << OutFPBits::FRACTION_LEN));
+
+ if constexpr (EXTRA_FRACTION_LEN > 0) {
+ if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) {
+ switch (quick_get_round()) {
+ case FE_TONEAREST:
+ case FE_UPWARD:
+ return OutFPBits::inf().get_val();
+ default:
+ return OutFPBits::max_normal().get_val();
+ }
+ }
+
+ // Round bit would sit above y's leading bit: sqrt(x) < min_subnormal / 2.
+ if (x_exp < -OutFPBits::FRACTION_LEN) {
+ switch (quick_get_round()) {
+ case FE_UPWARD:
+ return OutFPBits::min_subnormal().get_val();
+ default:
+ return OutType(0.0);
+ }
+ }
- y = (y - ONE) |
- (static_cast<StorageType>(x_exp) << FPBits_t::FRACTION_LEN);
+ if (x_exp <= 0) {
+ int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1;
+ // Bits strictly below rb only: including rb would break ties-to-even.
+ InStorageType underflow_sticky_mask =
+ (InStorageType(1) << (underflow_extra_fraction_len - 1)) - 1;
+ rb = (y & (InStorageType(1) << (underflow_extra_fraction_len - 1))) !=
+ 0;
+ OutStorageType subnormal_mant =
+ static_cast<OutStorageType>(y >> underflow_extra_fraction_len);
+ lsb = (subnormal_mant & 1) != 0;
+ sticky = sticky || (r != 0) || (y & underflow_sticky_mask) != 0;
+
+ switch (quick_get_round()) {
+ case FE_TONEAREST:
+ if (rb && (lsb || sticky))
+ ++subnormal_mant;
+ break;
+ case FE_UPWARD:
+ if (rb || sticky)
+ ++subnormal_mant;
+ break;
+ }
+
+ return cpp::bit_cast<OutType>(subnormal_mant);
+ }
+ }
switch (quick_get_round()) {
case FE_TONEAREST:
// Round to nearest, ties to even
if (rb && (lsb || (r != 0)))
- ++y;
+ ++y_out;
break;
case FE_UPWARD:
- if (rb || (r != 0))
- ++y;
+ if (rb || (r != 0) || sticky)
+ ++y_out;
break;
}
- return cpp::bit_cast<T>(y);
+ return cpp::bit_cast<OutType>(y_out);
}
}
}
diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index 2446c293b8ef5..df8e6c0b253da 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -99,6 +99,8 @@ add_math_entrypoint_object(exp10f)
add_math_entrypoint_object(expm1)
add_math_entrypoint_object(expm1f)
+add_math_entrypoint_object(f16sqrtf)
+
add_math_entrypoint_object(fabs)
add_math_entrypoint_object(fabsf)
add_math_entrypoint_object(fabsl)
diff --git a/libc/src/math/f16sqrtf.h b/libc/src/math/f16sqrtf.h
new file mode 100644
index 0000000000000..197ebe6db8016
--- /dev/null
+++ b/libc/src/math/f16sqrtf.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for f16sqrtf ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_F16SQRTF_H
+#define LLVM_LIBC_SRC_MATH_F16SQRTF_H
+
+#include "src/__support/macros/properties/types.h"
+
+namespace LIBC_NAMESPACE {
+
+float16 f16sqrtf(float x);
+
+} // namespace LIBC_NAMESPACE
+
+#endif // LLVM_LIBC_SRC_MATH_F16SQRTF_H
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index 673bef516b13d..f1f7d6c367be2 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -3601,3 +3601,16 @@ add_entrypoint_object(
COMPILE_OPTIONS
-O3
)
+
+add_entrypoint_object(
+ f16sqrtf
+ SRCS
+ f16sqrtf.cpp
+ HDRS
+ ../f16sqrtf.h
+ DEPENDS
+ libc.src.__support.macros.properties.types
+ libc.src.__support.FPUtil.sqrt
+ COMPILE_OPTIONS
+ -O3
+)
diff --git a/libc/src/math/generic/acosf.cpp b/libc/src/math/generic/acosf.cpp
index e6e28d43ef61f..f02edec267174 100644
--- a/libc/src/math/generic/acosf.cpp
+++ b/libc/src/math/generic/acosf.cpp
@@ -113,7 +113,7 @@ LLVM_LIBC_FUNCTION(float, acosf, (float x)) {
xbits.set_sign(Sign::POS);
double xd = static_cast<double>(xbits.get_val());
double u = fputil::multiply_add(-0.5, xd, 0.5);
- double cv = 2 * fputil::sqrt(u);
+ double cv = 2 * fputil::sqrt<double>(u);
double r3 = asin_eval(u);
double r = fputil::multiply_add(cv * u, r3, cv);
diff --git a/libc/src/math/generic/acoshf.cpp b/libc/src/math/generic/acoshf.cpp
index a4a75a7b04385..9422ec63e1ce2 100644
--- a/libc/src/math/generic/acoshf.cpp
+++ b/libc/src/math/generic/acoshf.cpp
@@ -66,8 +66,8 @@ LLVM_LIBC_FUNCTION(float, acoshf, (float x)) {
double x_d = static_cast<double>(x);
// acosh(x) = log(x + sqrt(x^2 - 1))
- return static_cast<float>(
- log_eval(x_d + fputil::sqrt(fputil::multiply_add(x_d, x_d, -1.0))));
+ return static_cast<float>(log_eval(
+ x_d + fputil::sqrt<double>(fputil::multiply_add(x_d, x_d, -1.0))));
}
} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/asinf.cpp b/libc/src/math/generic/asinf.cpp
index d9133333d2561..c4afca493a713 100644
--- a/libc/src/math/generic/asinf.cpp
+++ b/libc/src/math/generic/asinf.cpp
@@ -144,7 +144,7 @@ LLVM_LIBC_FUNCTION(float, asinf, (float x)) {
double sign = SIGN[x_sign];
double xd = static_cast<double>(xbits.get_val());
double u = fputil::multiply_add(-0.5, xd, 0.5);
- double c1 = sign * (-2 * fputil::sqrt(u));
+ double c1 = sign * (-2 * fputil::sqrt<double>(u));
double c2 = fputil::multiply_add(sign, M_MATH_PI_2, c1);
double c3 = c1 * u;
diff --git a/libc/src/math/generic/asinhf.cpp b/libc/src/math/generic/asinhf.cpp
index 6e351786e3eca..82dc2a31ebc22 100644
--- a/libc/src/math/generic/asinhf.cpp
+++ b/libc/src/math/generic/asinhf.cpp
@@ -97,9 +97,9 @@ LLVM_LIBC_FUNCTION(float, asinhf, (float x)) {
// asinh(x) = log(x + sqrt(x^2 + 1))
return static_cast<float>(
- x_sign *
- log_eval(fputil::multiply_add(
- x_d, x_sign, fputil::sqrt(fputil::multiply_add(x_d, x_d, 1.0)))));
+ x_sign * log_eval(fputil::multiply_add(
+ x_d, x_sign,
+ fputil::sqrt<double>(fputil::multiply_add(x_d, x_d, 1.0)))));
}
} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/f16sqrtf.cpp b/libc/src/math/generic/f16sqrtf.cpp
new file mode 100644
index 0000000000000..1f7ee2df29e86
--- /dev/null
+++ b/libc/src/math/generic/f16sqrtf.cpp
@@ -0,0 +1,19 @@
+//===-- Implementation of f16sqrtf function -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/f16sqrtf.h"
+#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/common.h"
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(float16, f16sqrtf, (float x)) {
+ return fputil::sqrt<float16>(x);
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/hypotf.cpp b/libc/src/math/generic/hypotf.cpp
index ffbf706aefaf6..b09d09ad7f9c9 100644
--- a/libc/src/math/generic/hypotf.cpp
+++ b/libc/src/math/generic/hypotf.cpp
@@ -42,7 +42,7 @@ LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq;
// Take sqrt in double precision.
- DoubleBits result(fputil::sqrt(sum_sq));
+ DoubleBits result(fputil::sqrt<double>(sum_sq));
if (!DoubleBits(sum_sq).is_inf_or_nan()) {
// Correct rounding.
diff --git a/libc/src/math/generic/powf.cpp b/libc/src/math/generic/powf.cpp
index 59efc3f424c76..13c04240f59c2 100644
--- a/libc/src/math/generic/powf.cpp
+++ b/libc/src/math/generic/powf.cpp
@@ -562,7 +562,7 @@ LLVM_LIBC_FUNCTION(float, powf, (float x, float y)) {
switch (y_u) {
case 0x3f00'0000: // y = 0.5f
// pow(x, 1/2) = sqrt(x)
- return fputil::sqrt(x);
+ return fputil::sqrt<float>(x);
case 0x3f80'0000: // y = 1.0f
return x;
case 0x4000'0000: // y = 2.0f
diff --git a/libc/src/math/generic/sqrt.cpp b/libc/src/math/generic/sqrt.cpp
index b4d02785dcb43..f33b0a2cdcf74 100644
--- a/libc/src/math/generic/sqrt.cpp
+++ b/libc/src/math/generic/sqrt.cpp
@@ -12,6 +12,6 @@
namespace LIBC_NAMESPACE {
-LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt(x); }
+LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt<double>(x); }
} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrtf.cpp b/libc/src/math/generic/sqrtf.cpp
index bc74252295b3a..26a53e9077c1c 100644
--- a/libc/src/math/generic/sqrtf.cpp
+++ b/libc/src/math/generic/sqrtf.cpp
@@ -12,6 +12,6 @@
namespace LIBC_NAMESPACE {
-LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt(x); }
+LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt<float>(x); }
} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrtf128.cpp b/libc/src/math/generic/sqrtf128.cpp
index 0196c3e0a96ae..70e28ddb692d4 100644
--- a/libc/src/math/generic/sqrtf128.cpp
+++ b/libc/src/math/generic/sqrtf128.cpp
@@ -12,6 +12,8 @@
namespace LIBC_NAMESPACE {
-LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) { return fputil::sqrt(x); }
+LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
+ return fputil::sqrt<float128>(x);
+}
} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrtl.cpp b/libc/src/math/generic/sqrtl.cpp
index b2aaa279f9c2a..9f0cc87853823 100644
--- a/libc/src/math/generic/sqrtl.cpp
+++ b/libc/src/math/generic/sqrtl.cpp
@@ -13,7 +13,7 @@
namespace LIBC_NAMESPACE {
LLVM_LIBC_FUNCTION(long double, sqrtl, (long double x)) {
- return fputil::sqrt(x);
+ return fputil::sqrt<long double>(x);
}
} // namespace LIBC_NAMESPACE
diff --git a/libc/test/src/math/exhaustive/CMakeLists.txt b/libc/test/src/math/exhaustive/CMakeLists.txt
index 938e519aff084..34df8720ed4db 100644
--- a/libc/test/src/math/exhaustive/CMakeLists.txt
+++ b/libc/test/src/math/exhaustive/CMakeLists.txt
@@ -420,3 +420,18 @@ add_fp_unittest(
LINK_LIBRARIES
-lpthread
)
+
+add_fp_unittest(
+ f16sqrtf_test
+ NO_RUN_POSTBUILD
+ NEED_MPFR
+ SUITE
+ libc_math_exhaustive_tests
+ SRCS
+ f16sqrtf_test.cpp
+ DEPENDS
+ .exhaustive_test
+ libc.src.math.f16sqrtf
+ LINK_LIBRARIES
+ -lpthread
+)
diff --git a/libc/test/src/math/exhaustive/exhaustive_test.h b/libc/test/src/math/exhaustive/exhaustive_test.h
index c4ae382688a03..13e272783250b 100644
--- a/libc/test/src/math/exhaustive/exhaustive_test.h
+++ b/libc/test/src/math/exhaustive/exhaustive_test.h
@@ -35,16 +35,16 @@
// LlvmLibcUnaryOpExhaustiveMathTest<FloatType, Op, Func>.
namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
-template <typename T> using UnaryOp = T(T);
+template <typename OutType, typename InType = OutType>
+using UnaryOp = OutType(InType);
-template <typename T, mpfr::Operation Op, UnaryOp<T> Func>
+template <typename OutType, typename InType, mpfr::Operation Op,
+ UnaryOp<OutType, InType> Func>
struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
- using FloatType = T;
+ using FloatType = InType;
using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
using StorageType = typename FPBits::StorageType;
- static constexpr UnaryOp<FloatType> *FUNC = Func;
-
// Check in a range, return the number of failures.
uint64_t check(StorageType start, StorageType stop,
mpfr::RoundingMode rounding) {
@@ -57,11 +57,11 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
FPBits xbits(bits);
FloatType x = xbits.get_val();
bool correct =
- TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, FUNC(x), 0.5, rounding);
+ TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, Func(x), 0.5, rounding);
failed += (!correct);
// Uncomment to print out failed values.
// if (!correct) {
- // TEST_MPFR_MATCH(Op::Operation, x, Op::func(x), 0.5, rounding);
+ // EXPECT_MPFR_MATCH_ROUNDING(Op, x, Func(x), 0.5, rounding);
// }
} while (bits++ < stop);
return failed;
@@ -169,4 +169,9 @@ struct LlvmLibcExhaustiveMathTest
template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func>
using LlvmLibcUnaryOpExhaustiveMathTest =
- LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, Op, Func>>;
+ LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, FloatType, Op, Func>>;
+
+template <typename OutType, typename InType, mpfr::Operation Op,
+ UnaryOp<OutType, InType> Func>
+using LlvmLibcUnaryNarrowingOpExhaustiveMathTest =
+ LlvmLibcExhaustiveMathTest<UnaryOpChecker<OutType, InType, Op, Func>>;
diff --git a/libc/test/src/math/exhaustive/f16sqrtf_test.cpp b/libc/test/src/math/exhaustive/f16sqrtf_test.cpp
new file mode 100644
index 0000000000000..3a42ff8e0725d
--- /dev/null
+++ b/libc/test/src/math/exhaustive/f16sqrtf_test.cpp
@@ -0,0 +1,25 @@
+//===-- Exhaustive test for f16sqrtf --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "exhaustive_test.h"
+#include "src/math/f16sqrtf.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
+
+using LlvmLibcF16sqrtfExhaustiveTest =
+ LlvmLibcUnaryNarrowingOpExhaustiveMathTest<
+ float16, float, mpfr::Operation::Sqrt, LIBC_NAMESPACE::f16sqrtf>;
+
+// Range: [+0, +Inf], i.e. every non-negative float bit pattern up to +Inf.
+static constexpr uint32_t POS_START = 0x0000'0000U;
+static constexpr uint32_t POS_STOP = 0x7f80'0000U;
+
+TEST_F(LlvmLibcF16sqrtfExhaustiveTest, PositiveRange) {
+ test_full_range_all_roundings(POS_START, POS_STOP);
+}
diff --git a/libc/test/src/math/smoke/CMakeLists.txt b/libc/test/src/math/smoke/CMakeLists.txt
index 68cd412b14e9d..3bb87d2b0d0f3 100644
--- a/libc/test/src/math/smoke/CMakeLists.txt
+++ b/libc/test/src/math/smoke/CMakeLists.txt
@@ -2504,9 +2504,10 @@ add_fp_unittest(
libc-math-smoke-tests
SRCS
sqrtf_test.cpp
+ HDRS
+ SqrtTest.h
DEPENDS
libc.src.math.sqrtf
- libc.src.__support.FPUtil.fp_bits
)
add_fp_unittest(
@@ -2515,9 +2516,10 @@ add_fp_unittest(
libc-math-smoke-tests
SRCS
sqrt_test.cpp
+ HDRS
+ SqrtTest.h
DEPENDS
libc.src.math.sqrt
- libc.src.__support.FPUtil.fp_bits
)
add_fp_unittest(
@@ -2526,9 +2528,10 @@ add_fp_unittest(
libc-math-smoke-tests
SRCS
sqrtl_test.cpp
+ HDRS
+ SqrtTest.h
DEPENDS
libc.src.math.sqrtl
- libc.src.__support.FPUtil.fp_bits
)
add_fp_unittest(
@@ -2537,9 +2540,10 @@ add_fp_unittest(
libc-math-smoke-tests
SRCS
sqrtf128_test.cpp
+ HDRS
+ SqrtTest.h
DEPENDS
libc.src.math.sqrtf128
- libc.src.__support.FPUtil.fp_bits
)
add_fp_unittest(
@@ -2548,9 +2552,9 @@ add_fp_unittest(
libc-math-smoke-tests
SRCS
generic_sqrtf_test.cpp
+ HDRS
+ SqrtTest.h
DEPENDS
- libc.src.math.sqrtf
- libc.src.__support.FPUtil.fp_bits
libc.src.__support.FPUtil.generic.sqrt
COMPILE_OPTIONS
-O3
@@ -2562,9 +2566,9 @@ add_fp_unittest(
libc-math-smoke-tests
SRCS
generic_sqrt_test.cpp
+ HDRS
+ SqrtTest.h
DEPENDS
- libc.src.math.sqrt
- libc.src.__support.FPUtil.fp_bits
libc.src.__support.FPUtil.generic.sqrt
COMPILE_OPTIONS
-O3
@@ -2576,9 +2580,9 @@ add_fp_unittest(
libc-math-smoke-tests
SRCS
generic_sqrtl_test.cpp
+ HDRS
+ SqrtTest.h
DEPENDS
- libc.src.math.sqrtl
- libc.src.__support.FPUtil.fp_bits
libc.src.__support.FPUtil.generic.sqrt
COMPILE_OPTIONS
-O3
@@ -2590,9 +2594,9 @@ add_fp_unittest(
libc-math-smoke-tests
SRCS
generic_sqrtf128_test.cpp
+ HDRS
+ SqrtTest.h
DEPENDS
- libc.src.math.sqrtf128
- libc.src.__support.FPUtil.fp_bits
libc.src.__support.FPUtil.generic.sqrt
COMPILE_OPTIONS
-O3
@@ -3543,3 +3547,15 @@ add_fp_unittest(
DEPENDS
libc.src.math.totalordermagf16
)
+
+add_fp_unittest(
+ f16sqrtf_test
+ SUITE
+ libc-math-smoke-tests
+ SRCS
+ f16sqrtf_test.cpp
+ HDRS
+ SqrtTest.h
+ DEPENDS
+ libc.src.math.f16sqrtf
+)
diff --git a/libc/test/src/math/smoke/SqrtTest.h b/libc/test/src/math/smoke/SqrtTest.h
index 8afacaf01ae42..ce9f2f85b4604 100644
--- a/libc/test/src/math/smoke/SqrtTest.h
+++ b/libc/test/src/math/smoke/SqrtTest.h
@@ -6,37 +6,35 @@
//
//===----------------------------------------------------------------------===//
-#include "src/__support/CPP/bit.h"
#include "test/UnitTest/FEnvSafeTest.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
-#include "hdr/math_macros.h"
-
-template <typename T>
+template <typename OutType, typename InType>
class SqrtTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
- DECLARE_SPECIAL_CONSTANTS(T)
-
- static constexpr StorageType HIDDEN_BIT =
- StorageType(1) << LIBC_NAMESPACE::fputil::FPBits<T>::FRACTION_LEN;
+ DECLARE_SPECIAL_CONSTANTS(OutType)
public:
- typedef T (*SqrtFunc)(T);
+ typedef OutType (*SqrtFunc)(InType);
void test_special_numbers(SqrtFunc func) {
ASSERT_FP_EQ(aNaN, func(aNaN));
ASSERT_FP_EQ(inf, func(inf));
ASSERT_FP_EQ(aNaN, func(neg_inf));
- ASSERT_FP_EQ(0.0, func(0.0));
- ASSERT_FP_EQ(-0.0, func(-0.0));
- ASSERT_FP_EQ(aNaN, func(T(-1.0)));
- ASSERT_FP_EQ(T(1.0), func(T(1.0)));
- ASSERT_FP_EQ(T(2.0), func(T(4.0)));
- ASSERT_FP_EQ(T(3.0), func(T(9.0)));
+ ASSERT_FP_EQ(zero, func(zero));
+ ASSERT_FP_EQ(neg_zero, func(neg_zero));
+ ASSERT_FP_EQ(aNaN, func(InType(-1.0)));
+ ASSERT_FP_EQ(OutType(1.0), func(InType(1.0)));
+ ASSERT_FP_EQ(OutType(2.0), func(InType(4.0)));
+ ASSERT_FP_EQ(OutType(3.0), func(InType(9.0)));
}
};
#define LIST_SQRT_TESTS(T, func) \
- using LlvmLibcSqrtTest = SqrtTest<T>; \
+ using LlvmLibcSqrtTest = SqrtTest<T, T>; \
+ TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); }
+
+#define LIST_NARROWING_SQRT_TESTS(OutType, InType, func) \
+ using LlvmLibcSqrtTest = SqrtTest<OutType, InType>; \
TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); }
diff --git a/libc/test/src/math/smoke/f16sqrtf_test.cpp b/libc/test/src/math/smoke/f16sqrtf_test.cpp
new file mode 100644
index 0000000000000..36231aeb4184d
--- /dev/null
+++ b/libc/test/src/math/smoke/f16sqrtf_test.cpp
@@ -0,0 +1,13 @@
+//===-- Unittests for f16sqrtf --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SqrtTest.h"
+
+#include "src/math/f16sqrtf.h"
+
+LIST_NARROWING_SQRT_TESTS(float16, float, LIBC_NAMESPACE::f16sqrtf)
diff --git a/libc/utils/MPFRWrapper/CMakeLists.txt b/libc/utils/MPFRWrapper/CMakeLists.txt
index 6af6fd7707041..e74b02204ed6f 100644
--- a/libc/utils/MPFRWrapper/CMakeLists.txt
+++ b/libc/utils/MPFRWrapper/CMakeLists.txt
@@ -7,6 +7,8 @@ if(LIBC_TESTS_CAN_USE_MPFR)
target_compile_options(libcMPFRWrapper PRIVATE -O3)
add_dependencies(
libcMPFRWrapper
+ libc.src.__support.CPP.array
+ libc.src.__support.CPP.stringstream
libc.src.__support.CPP.string_view
libc.src.__support.CPP.type_traits
libc.src.__support.FPUtil.fp_bits
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 6918139fa83b7..100c6b1644b16 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -8,8 +8,10 @@
#include "MPFRUtils.h"
+#include "src/__support/CPP/array.h"
#include "src/__support/CPP/string.h"
#include "src/__support/CPP/string_view.h"
+#include "src/__support/CPP/stringstream.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/fpbits_str.h"
#include "src/__support/macros/properties/types.h"
@@ -755,39 +757,51 @@ ternary_operation_one_output(Operation op, InputType x, InputType y,
// to build the complete error messages before sending it to the outstream `OS`
// once at the end. This will stop the error messages from interleaving when
// the tests are running concurrently.
-template <typename T>
-void explain_unary_operation_single_output_error(Operation op, T input,
- T matchValue,
+template <typename InputType, typename OutputType>
+void explain_unary_operation_single_output_error(Operation op, InputType input,
+ OutputType matchValue,
double ulp_tolerance,
RoundingMode rounding) {
- unsigned int precision = get_precision<T>(ulp_tolerance);
+ unsigned int precision = get_precision<InputType>(ulp_tolerance);
MPFRNumber mpfrInput(input, precision);
MPFRNumber mpfr_result;
mpfr_result = unary_operation(op, input, precision, rounding);
MPFRNumber mpfrMatchValue(matchValue);
- tlog << "Match value not within tolerance value of MPFR result:\n"
- << " Input decimal: " << mpfrInput.str() << '\n';
- tlog << " Input bits: " << str(FPBits<T>(input)) << '\n';
- tlog << '\n' << " Match decimal: " << mpfrMatchValue.str() << '\n';
- tlog << " Match bits: " << str(FPBits<T>(matchValue)) << '\n';
- tlog << '\n' << " MPFR result: " << mpfr_result.str() << '\n';
- tlog << " MPFR rounded: " << str(FPBits<T>(mpfr_result.as<T>())) << '\n';
- tlog << '\n';
- tlog << " ULP error: "
- << mpfr_result.ulp_as_mpfr_number(matchValue).str() << '\n';
+ cpp::array<char, 1024> msg_buf;
+ cpp::StringStream msg(msg_buf);
+ msg << "Match value not within tolerance value of MPFR result:\n"
+ << " Input decimal: " << mpfrInput.str() << '\n';
+ msg << " Input bits: " << str(FPBits<InputType>(input)) << '\n';
+ msg << '\n' << " Match decimal: " << mpfrMatchValue.str() << '\n';
+ msg << " Match bits: " << str(FPBits<OutputType>(matchValue)) << '\n';
+ msg << '\n' << " MPFR result: " << mpfr_result.str() << '\n';
+ msg << " MPFR rounded: "
+ << str(FPBits<OutputType>(mpfr_result.as<OutputType>())) << '\n';
+ msg << '\n';
+ msg << " ULP error: " << mpfr_result.ulp_as_mpfr_number(matchValue).str()
+ << '\n';
+ tlog << msg.str();
+ if (msg.overflow()) // Long reports can overflow: flag it rather than UB.
+ tlog << "\n(message truncated: exceeded message buffer)\n";
}
-template void explain_unary_operation_single_output_error<float>(Operation op,
- float, float,
- double,
- RoundingMode);
-template void explain_unary_operation_single_output_error<double>(
- Operation op, double, double, double, RoundingMode);
-template void explain_unary_operation_single_output_error<long double>(
- Operation op, long double, long double, double, RoundingMode);
+template void explain_unary_operation_single_output_error(Operation op, float,
+ float, double,
+ RoundingMode);
+template void explain_unary_operation_single_output_error(Operation op, double,
+ double, double,
+ RoundingMode);
+template void explain_unary_operation_single_output_error(Operation op,
+ long double,
+ long double, double,
+ RoundingMode);
#ifdef LIBC_TYPES_HAS_FLOAT16
-template void explain_unary_operation_single_output_error<float16>(
- Operation op, float16, float16, double, RoundingMode);
+template void explain_unary_operation_single_output_error(Operation op, float16,
+ float16, double,
+ RoundingMode);
+template void explain_unary_operation_single_output_error(Operation op, float,
+ float16, double,
+ RoundingMode);
#endif
template <typename T>
@@ -949,29 +963,30 @@ template void explain_ternary_operation_one_output_error<long double>(
Operation, const TernaryInput<long double> &, long double, double,
RoundingMode);
-template <typename T>
-bool compare_unary_operation_single_output(Operation op, T input, T libc_result,
+template <typename InputType, typename OutputType>
+bool compare_unary_operation_single_output(Operation op, InputType input,
+ OutputType libc_result,
double ulp_tolerance,
RoundingMode rounding) {
- unsigned int precision = get_precision<T>(ulp_tolerance);
+ unsigned int precision = get_precision<InputType>(ulp_tolerance);
MPFRNumber mpfr_result;
mpfr_result = unary_operation(op, input, precision, rounding);
double ulp = mpfr_result.ulp(libc_result);
return (ulp <= ulp_tolerance);
}
-template bool compare_unary_operation_single_output<float>(Operation, float,
- float, double,
- RoundingMode);
-template bool compare_unary_operation_single_output<double>(Operation, double,
- double, double,
- RoundingMode);
-template bool compare_unary_operation_single_output<long double>(
- Operation, long double, long double, double, RoundingMode);
+template bool compare_unary_operation_single_output(Operation, float, float,
+ double, RoundingMode);
+template bool compare_unary_operation_single_output(Operation, double, double,
+ double, RoundingMode);
+template bool compare_unary_operation_single_output(Operation, long double,
+ long double, double,
+ RoundingMode);
#ifdef LIBC_TYPES_HAS_FLOAT16
-template bool compare_unary_operation_single_output<float16>(Operation, float16,
- float16, double,
- RoundingMode);
+template bool compare_unary_operation_single_output(Operation, float16, float16,
+ double, RoundingMode);
+template bool compare_unary_operation_single_output(Operation, float, float16,
+ double, RoundingMode);
#endif
template <typename T>
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index d2f73e2628e16..805678b96c2ef 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -129,8 +129,9 @@ struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
static constexpr bool VALUE = cpp::is_floating_point_v<T>;
};
-template <typename T>
-bool compare_unary_operation_single_output(Operation op, T input, T libc_output,
+template <typename InputType, typename OutputType>
+bool compare_unary_operation_single_output(Operation op, InputType input,
+ OutputType libc_output,
double ulp_tolerance,
RoundingMode rounding);
template <typename T>
@@ -157,9 +158,9 @@ bool compare_ternary_operation_one_output(Operation op,
T libc_output, double ulp_tolerance,
RoundingMode rounding);
-template <typename T>
-void explain_unary_operation_single_output_error(Operation op, T input,
- T match_value,
+template <typename InputType, typename OutputType>
+void explain_unary_operation_single_output_error(Operation op, InputType input,
+ OutputType match_value,
double ulp_tolerance,
RoundingMode rounding);
template <typename T>
@@ -212,7 +213,7 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
bool is_silent() const override { return silent; }
private:
- template <typename T> bool match(T in, T out) {
+ template <typename T, typename U> bool match(T in, U out) {
return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
rounding);
}
@@ -238,7 +239,7 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
rounding);
}
- template <typename T> void explain_error(T in, T out) {
+ template <typename T, typename U> void explain_error(T in, U out) {
explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
rounding);
}
@@ -271,6 +272,12 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
// types.
template <Operation op, typename InputType, typename OutputType>
constexpr bool is_valid_operation() {
+ constexpr bool IS_NARROWING_OP = op == Operation::Sqrt &&
+ cpp::is_floating_point_v<InputType> &&
+ cpp::is_floating_point_v<OutputType> &&
+ sizeof(OutputType) <= sizeof(InputType);
+ if (IS_NARROWING_OP)
+ return true;
return (Operation::BeginUnaryOperationsSingleOutput < op &&
op < Operation::EndUnaryOperationsSingleOutput &&
cpp::is_same_v<InputType, OutputType> &&
More information about the libc-commits
mailing list