[libc-commits] [libc] [libc][math][c23] Add f16sqrtf C23 math function (PR #95251)

via libc-commits libc-commits at lists.llvm.org
Thu Jun 13 05:53:56 PDT 2024


https://github.com/overmighty updated https://github.com/llvm/llvm-project/pull/95251

>From 3d261bafc7586ba0b8a5c2d44bdd44282f632ae2 Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Tue, 11 Jun 2024 22:02:19 +0200
Subject: [PATCH 1/8] [libc][math][c23] Add f16sqrtf C23 math function

---
 libc/config/linux/aarch64/entrypoints.txt  |   1 +
 libc/config/linux/x86_64/entrypoints.txt   |   1 +
 libc/docs/math/index.rst                   |   2 +
 libc/spec/stdc.td                          |   2 +
 libc/src/__support/FPUtil/generic/sqrt.h   | 141 ++++++++++++++++-----
 libc/src/__support/FPUtil/sqrt.h           |   4 +-
 libc/src/math/CMakeLists.txt               |   2 +
 libc/src/math/f16sqrtf.h                   |  20 +++
 libc/src/math/generic/CMakeLists.txt       |  12 ++
 libc/src/math/generic/f16sqrtf.cpp         |  19 +++
 libc/src/math/generic/sqrt.cpp             |   2 +-
 libc/src/math/generic/sqrtf.cpp            |   2 +-
 libc/src/math/generic/sqrtf128.cpp         |   4 +-
 libc/src/math/generic/sqrtl.cpp            |   2 +-
 libc/test/src/math/smoke/CMakeLists.txt    |  11 ++
 libc/test/src/math/smoke/SqrtTest.h        |  28 ++--
 libc/test/src/math/smoke/f16sqrtf_test.cpp |  13 ++
 libc/test/src/math/smoke/sqrt_test.cpp     |   2 +-
 libc/test/src/math/smoke/sqrtf128_test.cpp |   2 +-
 libc/test/src/math/smoke/sqrtf_test.cpp    |   2 +-
 libc/test/src/math/smoke/sqrtl_test.cpp    |   2 +-
 21 files changed, 219 insertions(+), 55 deletions(-)
 create mode 100644 libc/src/math/f16sqrtf.h
 create mode 100644 libc/src/math/generic/f16sqrtf.cpp
 create mode 100644 libc/test/src/math/smoke/f16sqrtf_test.cpp

diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index db96a80051a8d..2b2d0985a8992 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -503,6 +503,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.canonicalizef16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
+    libc.src.math.f16sqrtf
     libc.src.math.fabsf16
     libc.src.math.fdimf16
     libc.src.math.floorf16
diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index 355eaf33ace6d..2d36ca296c3a4 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -535,6 +535,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.canonicalizef16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
+    libc.src.math.f16sqrtf
     libc.src.math.fabsf16
     libc.src.math.fdimf16
     libc.src.math.floorf16
diff --git a/libc/docs/math/index.rst b/libc/docs/math/index.rst
index d556885eda622..8243b14ff4786 100644
--- a/libc/docs/math/index.rst
+++ b/libc/docs/math/index.rst
@@ -282,6 +282,8 @@ Higher Math Functions
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | fsqrt     | N/A              |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
+| f16sqrt   | |check|          |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
++-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | hypot     | |check|          | |check|         |                        |                      |                        | 7.12.7.4               | F.10.4.4                   |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | lgamma    |                  |                 |                        |                      |                        | 7.12.8.3               | F.10.5.3                   |
diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td
index b134ec00a7d7a..7c4135032a0b2 100644
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -714,6 +714,8 @@ def StdC : StandardSpec<"stdc"> {
           GuardedFunctionSpec<"totalorderf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,
 
           GuardedFunctionSpec<"totalordermagf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,
+
+          GuardedFunctionSpec<"f16sqrtf", RetValSpec<Float16Type>, [ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
       ]
   >;
 
diff --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
index 7e7600ba6502a..4c95053217228 100644
--- a/libc/src/__support/FPUtil/generic/sqrt.h
+++ b/libc/src/__support/FPUtil/generic/sqrt.h
@@ -17,6 +17,7 @@
 #include "src/__support/FPUtil/rounding_mode.h"
 #include "src/__support/common.h"
 #include "src/__support/uint128.h"
+#include <fenv.h>
 
 namespace LIBC_NAMESPACE {
 namespace fputil {
@@ -64,40 +65,52 @@ LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {

 // Correctly rounded IEEE 754 SQRT for all rounding modes.
 // Shift-and-add algorithm.
-template <typename T>
-LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
-
-  if constexpr (internal::SpecialLongDouble<T>::VALUE) {
+// When OutType is narrower than InType (e.g. float16 from float), the root is
+// computed with InType's precision and then rounded once to OutType.
+template <typename OutType, typename InType>
+LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
+                                 cpp::is_floating_point_v<InType> &&
+                                 sizeof(OutType) <= sizeof(InType),
+                             OutType>
+sqrt(InType x) {
+  if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
+                internal::SpecialLongDouble<InType>::VALUE) {
     // Special 80-bit long double.
     return x86::sqrt(x);
   } else {
     // IEEE floating points formats.
-    using FPBits_t = typename fputil::FPBits<T>;
-    using StorageType = typename FPBits_t::StorageType;
-    constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN;
-    constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val();
-
-    FPBits_t bits(x);
-
-    if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
+    using OutFPBits = typename fputil::FPBits<OutType>;
+    using OutStorageType = typename OutFPBits::StorageType;
+    using InFPBits = typename fputil::FPBits<InType>;
+    using InStorageType = typename InFPBits::StorageType;
+    constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
+    constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();
+    constexpr int EXTRA_FRACTION_LEN =
+        InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
+    constexpr InStorageType EXTRA_FRACTION_MASK =
+        (InStorageType(1) << EXTRA_FRACTION_LEN) - 1;
+
+    InFPBits bits(x);
+
+    if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
       // sqrt(+Inf) = +Inf
       // sqrt(+0) = +0
       // sqrt(-0) = -0
       // sqrt(NaN) = NaN
       // sqrt(-NaN) = -NaN
-      return x;
+      return static_cast<OutType>(x);
     } else if (bits.is_neg()) {
       // sqrt(-Inf) = NaN
       // sqrt(-x) = NaN
       return FLT_NAN;
     } else {
       int x_exp = bits.get_exponent();
-      StorageType x_mant = bits.get_mantissa();
+      InStorageType x_mant = bits.get_mantissa();

       // Step 1a: Normalize denormal input and append hidden bit to the mantissa
       if (bits.is_subnormal()) {
         ++x_exp; // let x_exp be the correct exponent of ONE bit.
-        internal::normalize<T>(x_exp, x_mant);
+        internal::normalize<InType>(x_exp, x_mant);
       } else {
         x_mant |= ONE;
       }
@@ -120,12 +133,13 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
       // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
       //   y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
       //         0 otherwise.
-      StorageType y = ONE;
-      StorageType r = x_mant - ONE;
+      InStorageType y = ONE;
+      InStorageType r = x_mant - ONE;

-      for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
+      for (InStorageType current_bit = ONE >> 1; current_bit;
+           current_bit >>= 1) {
         r <<= 1;
-        StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
+        InStorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
         if (r >= tmp) {
           r -= tmp;
           y += current_bit;
@@ -133,34 +147,105 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
       }

       // We compute one more iteration in order to round correctly.
-      bool lsb = static_cast<bool>(y & 1); // Least significant bit
-      bool rb = false;                     // Round bit
+      bool lsb = (y & (InStorageType(1) << EXTRA_FRACTION_LEN)) !=
+                 0;    // Least significant bit
+      bool rb = false; // Round bit
       r <<= 2;
-      StorageType tmp = (y << 2) + 1;
+      InStorageType tmp = (y << 2) + 1;
       if (r >= tmp) {
         r -= tmp;
         rb = true;
       }

+      // Sticky bit: set when any nonzero part of the exact result lies
+      // strictly below the round bit. A nonzero remainder always counts.
+      bool sticky = r != 0;
+
+      if constexpr (EXTRA_FRACTION_LEN > 0) {
+        // When narrowing, the round bit moves down to bit
+        // (EXTRA_FRACTION_LEN - 1) of y. The bits below it and the round bit
+        // computed above become sticky. The new round bit itself must stay
+        // out of the sticky bit, or exact halfway cases never round to even.
+        sticky = sticky || rb || (y & (EXTRA_FRACTION_MASK >> 1)) != 0;
+        rb = (y & (InStorageType(1) << (EXTRA_FRACTION_LEN - 1))) != 0;
+      }
+
       // Remove hidden bit and append the exponent field.
-      x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS);
+      x_exp = ((x_exp >> 1) + OutFPBits::EXP_BIAS);
+
+      OutStorageType y_out = static_cast<OutStorageType>(
+          ((y - ONE) >> EXTRA_FRACTION_LEN) |
+          (static_cast<OutStorageType>(x_exp) << OutFPBits::FRACTION_LEN));
+
+      if constexpr (EXTRA_FRACTION_LEN > 0) {
+        // Overflow: the biased exponent does not fit in OutType. (A result
+        // that rounds up from max_normal needs no special case: incrementing
+        // y_out below carries into the exponent field and produces inf.)
+        if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) {
+          switch (quick_get_round()) {
+          case FE_TONEAREST:
+          case FE_UPWARD:
+            return OutFPBits::inf().get_val();
+          default:
+            return OutFPBits::max_normal().get_val();
+          }
+        }

-      y = (y - ONE) |
-          (static_cast<StorageType>(x_exp) << FPBits_t::FRACTION_LEN);
+        // The result is below half of the smallest subnormal, so it rounds to
+        // zero, or to the smallest subnormal when rounding upward. This also
+        // bounds the shift amounts used in the subnormal path below.
+        if (x_exp < -OutFPBits::FRACTION_LEN) {
+          switch (quick_get_round()) {
+          case FE_UPWARD:
+            return OutFPBits::min_subnormal().get_val();
+          default:
+            return OutType(0.0);
+          }
+        }
+
+        // Subnormal result: realign the significand (hidden bit included) to
+        // the subnormal encoding and recompute the rounding bits there.
+        if (x_exp <= 0) {
+          int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1;
+          InStorageType underflow_extra_fraction_mask =
+              (InStorageType(1) << underflow_extra_fraction_len) - 1;
+
+          rb = (y & (InStorageType(1) << (underflow_extra_fraction_len - 1))) !=
+               0;
+          OutStorageType subnormal_mant =
+              static_cast<OutStorageType>(y >> underflow_extra_fraction_len);
+          lsb = (subnormal_mant & 1) != 0;
+          // As above, keep the new round bit out of the sticky bit.
+          sticky = sticky || (y & (underflow_extra_fraction_mask >> 1)) != 0;
+
+          switch (quick_get_round()) {
+          case FE_TONEAREST:
+            if (rb && (lsb || sticky))
+              ++subnormal_mant;
+            break;
+          case FE_UPWARD:
+            if (rb || sticky)
+              ++subnormal_mant;
+            break;
+          }
+
+          return cpp::bit_cast<OutType>(subnormal_mant);
+        }
+      }

       switch (quick_get_round()) {
       case FE_TONEAREST:
         // Round to nearest, ties to even
-        if (rb && (lsb || (r != 0)))
-          ++y;
+        if (rb && (lsb || sticky))
+          ++y_out;
         break;
       case FE_UPWARD:
-        if (rb || (r != 0))
-          ++y;
+        if (rb || sticky)
+          ++y_out;
         break;
       }

-      return cpp::bit_cast<T>(y);
+      return cpp::bit_cast<OutType>(y_out);
     }
   }
 }
diff --git a/libc/src/__support/FPUtil/sqrt.h b/libc/src/__support/FPUtil/sqrt.h
index eb86ddfa89d8e..d9c30c586bb0d 100644
--- a/libc/src/__support/FPUtil/sqrt.h
+++ b/libc/src/__support/FPUtil/sqrt.h
@@ -13,7 +13,9 @@
 #include "src/__support/macros/properties/cpu_features.h"
 
 #if defined(LIBC_TARGET_ARCH_IS_X86_64) && defined(LIBC_TARGET_CPU_HAS_SSE2)
-#include "x86_64/sqrt.h"
+// #include "x86_64/sqrt.h"
+// TODO: re-enable; x86_64/sqrt.h presumably needs the <OutType, InType> API.
+#include "generic/sqrt.h"
 #elif defined(LIBC_TARGET_ARCH_IS_AARCH64)
 #include "aarch64/sqrt.h"
 #elif defined(LIBC_TARGET_ARCH_IS_ANY_RISCV)
diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index 2446c293b8ef5..df8e6c0b253da 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -99,6 +99,8 @@ add_math_entrypoint_object(exp10f)
 add_math_entrypoint_object(expm1)
 add_math_entrypoint_object(expm1f)
 
+add_math_entrypoint_object(f16sqrtf)
+
 add_math_entrypoint_object(fabs)
 add_math_entrypoint_object(fabsf)
 add_math_entrypoint_object(fabsl)
diff --git a/libc/src/math/f16sqrtf.h b/libc/src/math/f16sqrtf.h
new file mode 100644
index 0000000000000..197ebe6db8016
--- /dev/null
+++ b/libc/src/math/f16sqrtf.h
@@ -0,0 +1,21 @@
+//===-- Implementation header for f16sqrtf ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_F16SQRTF_H
+#define LLVM_LIBC_SRC_MATH_F16SQRTF_H
+
+#include "src/__support/macros/properties/types.h"
+
+namespace LIBC_NAMESPACE {
+
+// C23 f16sqrtf (7.12.14.6): square root of the float x, returned as float16.
+float16 f16sqrtf(float x);
+
+} // namespace LIBC_NAMESPACE
+
+#endif // LLVM_LIBC_SRC_MATH_F16SQRTF_H
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index 673bef516b13d..45a28723ba6b0 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -3601,3 +3601,15 @@ add_entrypoint_object(
   COMPILE_OPTIONS
     -O3
 )
+
+add_entrypoint_object(
+  f16sqrtf
+  SRCS
+    f16sqrtf.cpp
+  HDRS
+    ../f16sqrtf.h
+  DEPENDS
+    libc.src.__support.FPUtil.sqrt
+  COMPILE_OPTIONS
+    -O3
+)
diff --git a/libc/src/math/generic/f16sqrtf.cpp b/libc/src/math/generic/f16sqrtf.cpp
new file mode 100644
index 0000000000000..1f7ee2df29e86
--- /dev/null
+++ b/libc/src/math/generic/f16sqrtf.cpp
@@ -0,0 +1,21 @@
+//===-- Implementation of f16sqrtf function -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/f16sqrtf.h"
+#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/common.h"
+
+namespace LIBC_NAMESPACE {
+
+// Narrowing square root: takes a float argument and delegates to fputil::sqrt
+// with a float16 result type.
+LLVM_LIBC_FUNCTION(float16, f16sqrtf, (float x)) {
+  return fputil::sqrt</*OutType=*/float16, /*InType=*/float>(x);
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrt.cpp b/libc/src/math/generic/sqrt.cpp
index b4d02785dcb43..f33b0a2cdcf74 100644
--- a/libc/src/math/generic/sqrt.cpp
+++ b/libc/src/math/generic/sqrt.cpp
@@ -12,6 +12,6 @@
 
 namespace LIBC_NAMESPACE {
 
-LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt(x); }
+LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt<double>(x); }
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrtf.cpp b/libc/src/math/generic/sqrtf.cpp
index bc74252295b3a..26a53e9077c1c 100644
--- a/libc/src/math/generic/sqrtf.cpp
+++ b/libc/src/math/generic/sqrtf.cpp
@@ -12,6 +12,6 @@
 
 namespace LIBC_NAMESPACE {
 
-LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt(x); }
+LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt<float>(x); }
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrtf128.cpp b/libc/src/math/generic/sqrtf128.cpp
index 0196c3e0a96ae..70e28ddb692d4 100644
--- a/libc/src/math/generic/sqrtf128.cpp
+++ b/libc/src/math/generic/sqrtf128.cpp
@@ -12,6 +12,8 @@
 
 namespace LIBC_NAMESPACE {
 
-LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) { return fputil::sqrt(x); }
+LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
+  return fputil::sqrt<float128>(x);
+}
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrtl.cpp b/libc/src/math/generic/sqrtl.cpp
index b2aaa279f9c2a..9f0cc87853823 100644
--- a/libc/src/math/generic/sqrtl.cpp
+++ b/libc/src/math/generic/sqrtl.cpp
@@ -13,7 +13,7 @@
 namespace LIBC_NAMESPACE {
 
 LLVM_LIBC_FUNCTION(long double, sqrtl, (long double x)) {
-  return fputil::sqrt(x);
+  return fputil::sqrt<long double>(x);
 }
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/test/src/math/smoke/CMakeLists.txt b/libc/test/src/math/smoke/CMakeLists.txt
index 68cd412b14e9d..d67f5abd2ab1c 100644
--- a/libc/test/src/math/smoke/CMakeLists.txt
+++ b/libc/test/src/math/smoke/CMakeLists.txt
@@ -3543,3 +3543,14 @@ add_fp_unittest(
   DEPENDS
     libc.src.math.totalordermagf16
 )
+
+add_fp_unittest(
+  f16sqrtf_test
+  SUITE
+    libc-math-smoke-tests
+  SRCS
+    f16sqrtf_test.cpp
+  DEPENDS
+    libc.src.math.f16sqrtf
+    libc.src.__support.FPUtil.fp_bits
+)
diff --git a/libc/test/src/math/smoke/SqrtTest.h b/libc/test/src/math/smoke/SqrtTest.h
index 8afacaf01ae42..7731518308fef 100644
--- a/libc/test/src/math/smoke/SqrtTest.h
+++ b/libc/test/src/math/smoke/SqrtTest.h
@@ -6,37 +6,40 @@
 //
 //===----------------------------------------------------------------------===//

-#include "src/__support/CPP/bit.h"
 #include "test/UnitTest/FEnvSafeTest.h"
 #include "test/UnitTest/FPMatcher.h"
 #include "test/UnitTest/Test.h"

-#include "hdr/math_macros.h"
-
-template <typename T>
+template <typename OutType, typename InType>
 class SqrtTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {

-  DECLARE_SPECIAL_CONSTANTS(T)
-
-  static constexpr StorageType HIDDEN_BIT =
-      StorageType(1) << LIBC_NAMESPACE::fputil::FPBits<T>::FRACTION_LEN;
+  DECLARE_SPECIAL_CONSTANTS(OutType)

 public:
-  typedef T (*SqrtFunc)(T);
+  typedef OutType (*SqrtFunc)(InType);

   void test_special_numbers(SqrtFunc func) {
     ASSERT_FP_EQ(aNaN, func(aNaN));
     ASSERT_FP_EQ(inf, func(inf));
     ASSERT_FP_EQ(aNaN, func(neg_inf));
-    ASSERT_FP_EQ(0.0, func(0.0));
-    ASSERT_FP_EQ(-0.0, func(-0.0));
-    ASSERT_FP_EQ(aNaN, func(T(-1.0)));
-    ASSERT_FP_EQ(T(1.0), func(T(1.0)));
-    ASSERT_FP_EQ(T(2.0), func(T(4.0)));
-    ASSERT_FP_EQ(T(3.0), func(T(9.0)));
+    ASSERT_FP_EQ(zero, func(zero));
+    ASSERT_FP_EQ(neg_zero, func(neg_zero));
+    ASSERT_FP_EQ(aNaN, func(InType(-1.0)));
+    ASSERT_FP_EQ(OutType(1.0), func(InType(1.0)));
+    ASSERT_FP_EQ(OutType(2.0), func(InType(4.0)));
+    ASSERT_FP_EQ(OutType(3.0), func(InType(9.0)));
+
+    // Narrowing variants (e.g. float -> float16): the roots of InType's
+    // extreme positive values lie outside OutType's range, so in the default
+    // rounding mode they must overflow to +inf and underflow to +0.
+    if constexpr (sizeof(OutType) < sizeof(InType)) {
+      using InFPBits = LIBC_NAMESPACE::fputil::FPBits<InType>;
+      ASSERT_FP_EQ(inf, func(InFPBits::max_normal().get_val()));
+      ASSERT_FP_EQ(zero, func(InFPBits::min_subnormal().get_val()));
+    }
   }
 };

-#define LIST_SQRT_TESTS(T, func)                                               \
-  using LlvmLibcSqrtTest = SqrtTest<T>;                                        \
+#define LIST_SQRT_TESTS(OutType, InType, func)                                 \
+  using LlvmLibcSqrtTest = SqrtTest<OutType, InType>;                          \
   TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); }
diff --git a/libc/test/src/math/smoke/f16sqrtf_test.cpp b/libc/test/src/math/smoke/f16sqrtf_test.cpp
new file mode 100644
index 0000000000000..bf160ccd35c32
--- /dev/null
+++ b/libc/test/src/math/smoke/f16sqrtf_test.cpp
@@ -0,0 +1,13 @@
+//===-- Unittests for f16sqrtf --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SqrtTest.h"
+
+#include "src/math/f16sqrtf.h"
+
+LIST_SQRT_TESTS(float16, float, LIBC_NAMESPACE::f16sqrtf)
diff --git a/libc/test/src/math/smoke/sqrt_test.cpp b/libc/test/src/math/smoke/sqrt_test.cpp
index 1551b31d6f715..8e2e25dee5440 100644
--- a/libc/test/src/math/smoke/sqrt_test.cpp
+++ b/libc/test/src/math/smoke/sqrt_test.cpp
@@ -10,4 +10,4 @@
 
 #include "src/math/sqrt.h"
 
-LIST_SQRT_TESTS(double, LIBC_NAMESPACE::sqrt)
+LIST_SQRT_TESTS(double, double, LIBC_NAMESPACE::sqrt)
diff --git a/libc/test/src/math/smoke/sqrtf128_test.cpp b/libc/test/src/math/smoke/sqrtf128_test.cpp
index 23397b0623ce5..599e8af80ff00 100644
--- a/libc/test/src/math/smoke/sqrtf128_test.cpp
+++ b/libc/test/src/math/smoke/sqrtf128_test.cpp
@@ -10,4 +10,4 @@
 
 #include "src/math/sqrtf128.h"
 
-LIST_SQRT_TESTS(float128, LIBC_NAMESPACE::sqrtf128)
+LIST_SQRT_TESTS(float128, float128, LIBC_NAMESPACE::sqrtf128)
diff --git a/libc/test/src/math/smoke/sqrtf_test.cpp b/libc/test/src/math/smoke/sqrtf_test.cpp
index 3f2e973325bd0..13093efd24ae5 100644
--- a/libc/test/src/math/smoke/sqrtf_test.cpp
+++ b/libc/test/src/math/smoke/sqrtf_test.cpp
@@ -10,4 +10,4 @@
 
 #include "src/math/sqrtf.h"
 
-LIST_SQRT_TESTS(float, LIBC_NAMESPACE::sqrtf)
+LIST_SQRT_TESTS(float, float, LIBC_NAMESPACE::sqrtf)
diff --git a/libc/test/src/math/smoke/sqrtl_test.cpp b/libc/test/src/math/smoke/sqrtl_test.cpp
index f80bcfb736078..f49daf0f90edb 100644
--- a/libc/test/src/math/smoke/sqrtl_test.cpp
+++ b/libc/test/src/math/smoke/sqrtl_test.cpp
@@ -10,4 +10,4 @@
 
 #include "src/math/sqrtl.h"
 
-LIST_SQRT_TESTS(long double, LIBC_NAMESPACE::sqrtl)
+LIST_SQRT_TESTS(long double, long double, LIBC_NAMESPACE::sqrtl)

>From ae9f258e97befb51a7c4b3ad1f7e2d15993d23e1 Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Wed, 12 Jun 2024 15:48:24 +0200
Subject: [PATCH 2/8] [libc][math][c23] Add MPFR exhaustive test for f16sqrtf

---
 libc/test/src/math/exhaustive/CMakeLists.txt  | 15 ++++++
 .../src/math/exhaustive/exhaustive_test.h     | 40 +++++++++++++++
 .../src/math/exhaustive/f16sqrtf_test.cpp     | 25 ++++++++++
 libc/utils/MPFRWrapper/CMakeLists.txt         |  2 +
 libc/utils/MPFRWrapper/MPFRUtils.cpp          | 49 +++++++++++++++++++
 libc/utils/MPFRWrapper/MPFRUtils.h            | 22 ++++++++-
 6 files changed, 152 insertions(+), 1 deletion(-)
 create mode 100644 libc/test/src/math/exhaustive/f16sqrtf_test.cpp

diff --git a/libc/test/src/math/exhaustive/CMakeLists.txt b/libc/test/src/math/exhaustive/CMakeLists.txt
index 938e519aff084..34df8720ed4db 100644
--- a/libc/test/src/math/exhaustive/CMakeLists.txt
+++ b/libc/test/src/math/exhaustive/CMakeLists.txt
@@ -420,3 +420,18 @@ add_fp_unittest(
   LINK_LIBRARIES
     -lpthread
 )
+
+add_fp_unittest(
+  f16sqrtf_test
+  NO_RUN_POSTBUILD
+  NEED_MPFR
+  SUITE
+    libc_math_exhaustive_tests
+  SRCS
+    f16sqrtf_test.cpp
+  DEPENDS
+    .exhaustive_test
+    libc.src.math.f16sqrtf
+  LINK_LIBRARIES
+    -lpthread
+)
diff --git a/libc/test/src/math/exhaustive/exhaustive_test.h b/libc/test/src/math/exhaustive/exhaustive_test.h
index c4ae382688a03..1f8daf497ab2f 100644
--- a/libc/test/src/math/exhaustive/exhaustive_test.h
+++ b/libc/test/src/math/exhaustive/exhaustive_test.h
@@ -68,6 +68,45 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
   }
 };

+// Signature of a unary function whose result type differs from its argument
+// type, e.g. float16 f16sqrtf(float).
+template <typename OutType, typename InType>
+using UnaryNarrowerOp = OutType(InType);
+
+// Checks Func against MPFR for every InType bit pattern in [start, stop] in
+// the given rounding mode, requiring each result to be within 0.5 ULP.
+template <typename OutType, typename InType, mpfr::Operation Op,
+          UnaryNarrowerOp<OutType, InType> Func>
+struct UnaryNarrowerOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
+  using FloatType = InType;
+  using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
+  using StorageType = typename FPBits::StorageType;
+
+  static constexpr UnaryNarrowerOp<OutType, FloatType> *FUNC = Func;
+
+  // Check in a range, return the number of failures.
+  uint64_t check(StorageType start, StorageType stop,
+                 mpfr::RoundingMode rounding) {
+    mpfr::ForceRoundingMode r(rounding);
+    if (!r.success)
+      return (stop > start);
+    StorageType bits = start;
+    uint64_t failed = 0;
+    do {
+      FPBits xbits(bits);
+      FloatType x = xbits.get_val();
+      bool correct =
+          TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, FUNC(x), 0.5, rounding);
+      failed += (!correct);
+      // Report the details of every failing value.
+      if (!correct) {
+        EXPECT_MPFR_MATCH_ROUNDING(Op, x, FUNC(x), 0.5, rounding);
+      }
+    } while (bits++ < stop);
+    return failed;
+  }
+};
+
 // Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
 //   StorageType and check method.
 template <typename Checker>
@@ -170,3 +209,9 @@ struct LlvmLibcExhaustiveMathTest
 template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func>
 using LlvmLibcUnaryOpExhaustiveMathTest =
     LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, Op, Func>>;
+
+// Exhaustive test fixture for a UnaryNarrowerOp checked against MPFR.
+template <typename OutType, typename InType, mpfr::Operation Op,
+          UnaryNarrowerOp<OutType, InType> Func>
+using LlvmLibcUnaryNarrowerOpExhaustiveMathTest = LlvmLibcExhaustiveMathTest<
+    UnaryNarrowerOpChecker<OutType, InType, Op, Func>>;
diff --git a/libc/test/src/math/exhaustive/f16sqrtf_test.cpp b/libc/test/src/math/exhaustive/f16sqrtf_test.cpp
new file mode 100644
index 0000000000000..5bc04f6bdc7cf
--- /dev/null
+++ b/libc/test/src/math/exhaustive/f16sqrtf_test.cpp
@@ -0,0 +1,25 @@
+//===-- Exhaustive test for f16sqrtf --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "exhaustive_test.h"
+#include "src/math/f16sqrtf.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
+
+using LlvmLibcF16sqrtfExhaustiveTest =
+    LlvmLibcUnaryNarrowerOpExhaustiveMathTest<
+        float16, float, mpfr::Operation::Sqrt, LIBC_NAMESPACE::f16sqrtf>;
+
+// Range: [0, Inf], i.e. every non-negative float except NaNs.
+static constexpr uint32_t POS_START = 0x0000'0000U;
+static constexpr uint32_t POS_STOP = 0x7f80'0000U;
+
+TEST_F(LlvmLibcF16sqrtfExhaustiveTest, PositiveRange) {
+  test_full_range_all_roundings(POS_START, POS_STOP);
+}
diff --git a/libc/utils/MPFRWrapper/CMakeLists.txt b/libc/utils/MPFRWrapper/CMakeLists.txt
index 6af6fd7707041..08fae564dc273 100644
--- a/libc/utils/MPFRWrapper/CMakeLists.txt
+++ b/libc/utils/MPFRWrapper/CMakeLists.txt
@@ -7,7 +7,9 @@ if(LIBC_TESTS_CAN_USE_MPFR)
   target_compile_options(libcMPFRWrapper PRIVATE -O3)
   add_dependencies(
     libcMPFRWrapper
+    libc.src.__support.CPP.array
     libc.src.__support.CPP.string_view
+    libc.src.__support.CPP.stringstream
     libc.src.__support.CPP.type_traits
     libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.fpbits_str
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 6918139fa83b7..3de096fca3d04 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -8,8 +8,10 @@
 
 #include "MPFRUtils.h"
 
+#include "src/__support/CPP/array.h"
 #include "src/__support/CPP/string.h"
 #include "src/__support/CPP/string_view.h"
+#include "src/__support/CPP/stringstream.h"
 #include "src/__support/FPUtil/FPBits.h"
 #include "src/__support/FPUtil/fpbits_str.h"
 #include "src/__support/macros/properties/types.h"
@@ -790,6 +792,39 @@ template void explain_unary_operation_single_output_error<float16>(
     Operation op, float16, float16, double, RoundingMode);
 #endif

+// Logs why a narrowing unary operation's result (matchValue, of OutType) is
+// not within tolerance of the MPFR result for the InType input.
+template <typename OutType, typename InType>
+void explain_unary_narrower_operation_single_output_error(
+    Operation op, InType input, OutType matchValue, double ulp_tolerance,
+    RoundingMode rounding) {
+  unsigned int precision = get_precision<InType>(ulp_tolerance);
+  MPFRNumber mpfrInput(input, precision);
+  MPFRNumber mpfr_result;
+  mpfr_result = unary_operation(op, input, precision, rounding);
+  MPFRNumber mpfrMatchValue(matchValue);
+  cpp::array<char, 4096> msg_data;
+  cpp::StringStream msg(msg_data);
+  msg << "Match value not within tolerance value of MPFR result:\n"
+      << "  Input decimal: " << mpfrInput.str() << '\n';
+  msg << "     Input bits: " << str(FPBits<InType>(input)) << '\n';
+  msg << '\n' << "  Match decimal: " << mpfrMatchValue.str() << '\n';
+  msg << "     Match bits: " << str(FPBits<OutType>(matchValue)) << '\n';
+  msg << '\n' << "    MPFR result: " << mpfr_result.str() << '\n';
+  msg << "   MPFR rounded: " << str(FPBits<OutType>(mpfr_result.as<OutType>()))
+      << '\n';
+  msg << '\n';
+  msg << "      ULP error: " << mpfr_result.ulp_as_mpfr_number(matchValue).str()
+      << '\n';
+  tlog << msg.str();
+}
+
+#ifdef LIBC_TYPES_HAS_FLOAT16
+template void
+explain_unary_narrower_operation_single_output_error<float16, float>(
+    Operation op, float, float16, double, RoundingMode);
+#endif
+
 template <typename T>
 void explain_unary_operation_two_outputs_error(
     Operation op, T input, const BinaryOutput<T> &libc_result,
@@ -974,6 +1009,24 @@ template bool compare_unary_operation_single_output<float16>(Operation, float16,
                                                              RoundingMode);
 #endif

+// Returns true when libc_result is within ulp_tolerance of the MPFR result of
+// op applied to input, as measured by MPFRNumber::ulp(libc_result).
+template <typename OutType, typename InType>
+bool compare_unary_narrower_operation_single_output(Operation op, InType input,
+                                                    OutType libc_result,
+                                                    double ulp_tolerance,
+                                                    RoundingMode rounding) {
+  unsigned int precision = get_precision<InType>(ulp_tolerance);
+  MPFRNumber mpfr_result;
+  mpfr_result = unary_operation(op, input, precision, rounding);
+  double ulp = mpfr_result.ulp(libc_result);
+  return (ulp <= ulp_tolerance);
+}
+#ifdef LIBC_TYPES_HAS_FLOAT16
+template bool compare_unary_narrower_operation_single_output<float16, float>(
+    Operation, float, float16, double, RoundingMode);
+#endif
+
 template <typename T>
 bool compare_unary_operation_two_outputs(Operation op, T input,
                                          const BinaryOutput<T> &libc_result,
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index d2f73e2628e16..adccb19ea091c 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -133,6 +133,12 @@ template <typename T>
 bool compare_unary_operation_single_output(Operation op, T input, T libc_output,
                                            double ulp_tolerance,
                                            RoundingMode rounding);
+// Narrowing variant: OutType result of a unary operation on an InType input.
+template <typename OutType, typename InType>
+bool compare_unary_narrower_operation_single_output(Operation op, InType input,
+                                                    OutType libc_output,
+                                                    double ulp_tolerance,
+                                                    RoundingMode rounding);
 template <typename T>
 bool compare_unary_operation_two_outputs(Operation op, T input,
                                          const BinaryOutput<T> &libc_output,
@@ -162,6 +168,11 @@ void explain_unary_operation_single_output_error(Operation op, T input,
                                                  T match_value,
                                                  double ulp_tolerance,
                                                  RoundingMode rounding);
+// Narrowing variant: OutType result of a unary operation on an InType input.
+template <typename OutType, typename InType>
+void explain_unary_narrower_operation_single_output_error(
+    Operation op, InType input, OutType match_value, double ulp_tolerance,
+    RoundingMode rounding);
 template <typename T>
 void explain_unary_operation_two_outputs_error(
     Operation op, T input, const BinaryOutput<T> &match_value,
@@ -217,6 +228,12 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                                  rounding);
   }

+  // Unary operation with input type T and output type U (narrowing ops).
+  template <typename T, typename U> bool match(T in, U out) {
+    return compare_unary_narrower_operation_single_output(
+        op, in, out, ulp_tolerance, rounding);
+  }
+
   template <typename T> bool match(T in, const BinaryOutput<T> &out) {
     return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
                                                rounding);
@@ -243,6 +260,12 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                                 rounding);
   }

+  // Unary operation with input type T and output type U (narrowing ops).
+  template <typename T, typename U> void explain_error(T in, U out) {
+    explain_unary_narrower_operation_single_output_error(
+        op, in, out, ulp_tolerance, rounding);
+  }
+
   template <typename T> void explain_error(T in, const BinaryOutput<T> &out) {
     explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
                                               rounding);
@@ -271,7 +294,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
 // types.
 template <Operation op, typename InputType, typename OutputType>
 constexpr bool is_valid_operation() {
-  return (Operation::BeginUnaryOperationsSingleOutput < op &&
+  return (op == Operation::Sqrt) ||
+         (Operation::BeginUnaryOperationsSingleOutput < op &&
           op < Operation::EndUnaryOperationsSingleOutput &&
           cpp::is_same_v<InputType, OutputType> &&
           cpp::is_floating_point_v<InputType>) ||

>From d515f1833400091373b020b15621739e359f44b9 Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Wed, 12 Jun 2024 17:23:39 +0200
Subject: [PATCH 3/8] [libc][math][c23] Fix includes and refactor exhaustive
 test support

---
 .../__support/FPUtil/generic/CMakeLists.txt   |  1 +
 libc/src/__support/FPUtil/generic/sqrt.h      |  3 +-
 libc/src/__support/FPUtil/sqrt.h              |  4 +-
 .../src/math/exhaustive/exhaustive_test.h     | 20 ++--
 .../src/math/exhaustive/f16sqrtf_test.cpp     |  2 +-
 libc/test/src/math/smoke/CMakeLists.txt       | 31 +++---
 libc/utils/MPFRWrapper/CMakeLists.txt         |  2 +-
 libc/utils/MPFRWrapper/MPFRUtils.cpp          | 95 ++++++-------------
 libc/utils/MPFRWrapper/MPFRUtils.h            | 46 ++++-----
 9 files changed, 82 insertions(+), 122 deletions(-)

diff --git a/libc/src/__support/FPUtil/generic/CMakeLists.txt b/libc/src/__support/FPUtil/generic/CMakeLists.txt
index 09eede1570962..595656e3e8d90 100644
--- a/libc/src/__support/FPUtil/generic/CMakeLists.txt
+++ b/libc/src/__support/FPUtil/generic/CMakeLists.txt
@@ -4,6 +4,7 @@ add_header_library(
     sqrt.h
     sqrt_80_bit_long_double.h
   DEPENDS
+    libc.hdr.fenv_macros
     libc.src.__support.common
     libc.src.__support.CPP.bit
     libc.src.__support.CPP.type_traits
diff --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
index 4c95053217228..17446a12c298b 100644
--- a/libc/src/__support/FPUtil/generic/sqrt.h
+++ b/libc/src/__support/FPUtil/generic/sqrt.h
@@ -17,7 +17,8 @@
 #include "src/__support/FPUtil/rounding_mode.h"
 #include "src/__support/common.h"
 #include "src/__support/uint128.h"
-#include <fenv.h>
+
+#include "hdr/fenv_macros.h"
 
 namespace LIBC_NAMESPACE {
 namespace fputil {
diff --git a/libc/src/__support/FPUtil/sqrt.h b/libc/src/__support/FPUtil/sqrt.h
index d9c30c586bb0d..eb86ddfa89d8e 100644
--- a/libc/src/__support/FPUtil/sqrt.h
+++ b/libc/src/__support/FPUtil/sqrt.h
@@ -13,9 +13,7 @@
 #include "src/__support/macros/properties/cpu_features.h"
 
 #if defined(LIBC_TARGET_ARCH_IS_X86_64) && defined(LIBC_TARGET_CPU_HAS_SSE2)
-// #include "x86_64/sqrt.h"
-// TODO
-#include "generic/sqrt.h"
+#include "x86_64/sqrt.h"
 #elif defined(LIBC_TARGET_ARCH_IS_AARCH64)
 #include "aarch64/sqrt.h"
 #elif defined(LIBC_TARGET_ARCH_IS_ANY_RISCV)
diff --git a/libc/test/src/math/exhaustive/exhaustive_test.h b/libc/test/src/math/exhaustive/exhaustive_test.h
index 1f8daf497ab2f..ec6e9e79e36e8 100644
--- a/libc/test/src/math/exhaustive/exhaustive_test.h
+++ b/libc/test/src/math/exhaustive/exhaustive_test.h
@@ -69,16 +69,16 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
 };
 
 template <typename OutType, typename InType>
-using UnaryNarrowerOp = OutType(InType);
+using UnaryNarrowingOp = OutType(InType);
 
 template <typename OutType, typename InType, mpfr::Operation Op,
-          UnaryNarrowerOp<OutType, InType> Func>
-struct UnaryNarrowerOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
+          UnaryNarrowingOp<OutType, InType> Func>
+struct UnaryNarrowingOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
   using FloatType = InType;
   using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
   using StorageType = typename FPBits::StorageType;
 
-  static constexpr UnaryNarrowerOp<OutType, FloatType> *FUNC = Func;
+  static constexpr UnaryNarrowingOp<OutType, FloatType> *FUNC = Func;
 
   // Check in a range, return the number of failures.
   uint64_t check(StorageType start, StorageType stop,
@@ -95,9 +95,9 @@ struct UnaryNarrowerOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
           TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, FUNC(x), 0.5, rounding);
       failed += (!correct);
       // Uncomment to print out failed values.
-      if (!correct) {
-        EXPECT_MPFR_MATCH_ROUNDING(Op, x, FUNC(x), 0.5, rounding);
-      }
+      // if (!correct) {
+      //   EXPECT_MPFR_MATCH_ROUNDING(Op, x, FUNC(x), 0.5, rounding);
+      // }
     } while (bits++ < stop);
     return failed;
   }
@@ -207,6 +207,6 @@ using LlvmLibcUnaryOpExhaustiveMathTest =
     LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, Op, Func>>;
 
 template <typename OutType, typename InType, mpfr::Operation Op,
-          UnaryNarrowerOp<OutType, InType> Func>
-using LlvmLibcUnaryNarrowerOpExhaustiveMathTest = LlvmLibcExhaustiveMathTest<
-    UnaryNarrowerOpChecker<OutType, InType, Op, Func>>;
+          UnaryNarrowingOp<OutType, InType> Func>
+using LlvmLibcUnaryNarrowingOpExhaustiveMathTest = LlvmLibcExhaustiveMathTest<
+    UnaryNarrowingOpChecker<OutType, InType, Op, Func>>;
diff --git a/libc/test/src/math/exhaustive/f16sqrtf_test.cpp b/libc/test/src/math/exhaustive/f16sqrtf_test.cpp
index 5bc04f6bdc7cf..3a42ff8e0725d 100644
--- a/libc/test/src/math/exhaustive/f16sqrtf_test.cpp
+++ b/libc/test/src/math/exhaustive/f16sqrtf_test.cpp
@@ -13,7 +13,7 @@
 namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
 
 using LlvmLibcF16sqrtfExhaustiveTest =
-    LlvmLibcUnaryNarrowerOpExhaustiveMathTest<
+    LlvmLibcUnaryNarrowingOpExhaustiveMathTest<
         float16, float, mpfr::Operation::Sqrt, LIBC_NAMESPACE::f16sqrtf>;
 
 // Range: [0, Inf];
diff --git a/libc/test/src/math/smoke/CMakeLists.txt b/libc/test/src/math/smoke/CMakeLists.txt
index d67f5abd2ab1c..3bb87d2b0d0f3 100644
--- a/libc/test/src/math/smoke/CMakeLists.txt
+++ b/libc/test/src/math/smoke/CMakeLists.txt
@@ -2504,9 +2504,10 @@ add_fp_unittest(
     libc-math-smoke-tests
   SRCS
     sqrtf_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtf
-    libc.src.__support.FPUtil.fp_bits
 )
 
 add_fp_unittest(
@@ -2515,9 +2516,10 @@ add_fp_unittest(
     libc-math-smoke-tests
   SRCS
     sqrt_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrt
-    libc.src.__support.FPUtil.fp_bits
 )
 
 add_fp_unittest(
@@ -2526,9 +2528,10 @@ add_fp_unittest(
     libc-math-smoke-tests
   SRCS
     sqrtl_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtl
-    libc.src.__support.FPUtil.fp_bits
 )
 
 add_fp_unittest(
@@ -2537,9 +2540,10 @@ add_fp_unittest(
     libc-math-smoke-tests
   SRCS
     sqrtf128_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtf128
-    libc.src.__support.FPUtil.fp_bits
 )
 
 add_fp_unittest(
@@ -2548,9 +2552,9 @@ add_fp_unittest(
     libc-math-smoke-tests
   SRCS
     generic_sqrtf_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
-    libc.src.math.sqrtf
-    libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.generic.sqrt
   COMPILE_OPTIONS
     -O3
@@ -2562,9 +2566,9 @@ add_fp_unittest(
     libc-math-smoke-tests
   SRCS
     generic_sqrt_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
-    libc.src.math.sqrt
-    libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.generic.sqrt
   COMPILE_OPTIONS
     -O3
@@ -2576,9 +2580,9 @@ add_fp_unittest(
     libc-math-smoke-tests
   SRCS
     generic_sqrtl_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
-    libc.src.math.sqrtl
-    libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.generic.sqrt
   COMPILE_OPTIONS
     -O3
@@ -2590,9 +2594,9 @@ add_fp_unittest(
     libc-math-smoke-tests
   SRCS
     generic_sqrtf128_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
-    libc.src.math.sqrtf128
-    libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.generic.sqrt
   COMPILE_OPTIONS
     -O3
@@ -3550,7 +3554,8 @@ add_fp_unittest(
     libc-math-smoke-tests
   SRCS
     f16sqrtf_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.f16sqrtf
-    libc.src.__support.FPUtil.fp_bits
 )
diff --git a/libc/utils/MPFRWrapper/CMakeLists.txt b/libc/utils/MPFRWrapper/CMakeLists.txt
index 08fae564dc273..e74b02204ed6f 100644
--- a/libc/utils/MPFRWrapper/CMakeLists.txt
+++ b/libc/utils/MPFRWrapper/CMakeLists.txt
@@ -8,8 +8,8 @@ if(LIBC_TESTS_CAN_USE_MPFR)
   add_dependencies(
     libcMPFRWrapper
     libc.src.__support.CPP.array
-    libc.src.__support.CPP.string_view
     libc.src.__support.CPP.stringstream
+    libc.src.__support.CPP.string_view
     libc.src.__support.CPP.type_traits
     libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.fpbits_str
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 3de096fca3d04..e047a71566792 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -757,26 +757,32 @@ ternary_operation_one_output(Operation op, InputType x, InputType y,
 // to build the complete error messages before sending it to the outstream `OS`
 // once at the end.  This will stop the error messages from interleaving when
 // the tests are running concurrently.
-template <typename T>
-void explain_unary_operation_single_output_error(Operation op, T input,
-                                                 T matchValue,
+template <typename InputType, typename OutputType>
+void explain_unary_operation_single_output_error(Operation op, InputType input,
+                                                 OutputType matchValue,
                                                  double ulp_tolerance,
                                                  RoundingMode rounding) {
-  unsigned int precision = get_precision<T>(ulp_tolerance);
+  unsigned int precision = get_precision<InputType>(ulp_tolerance);
   MPFRNumber mpfrInput(input, precision);
   MPFRNumber mpfr_result;
   mpfr_result = unary_operation(op, input, precision, rounding);
   MPFRNumber mpfrMatchValue(matchValue);
-  tlog << "Match value not within tolerance value of MPFR result:\n"
-       << "  Input decimal: " << mpfrInput.str() << '\n';
-  tlog << "     Input bits: " << str(FPBits<T>(input)) << '\n';
-  tlog << '\n' << "  Match decimal: " << mpfrMatchValue.str() << '\n';
-  tlog << "     Match bits: " << str(FPBits<T>(matchValue)) << '\n';
-  tlog << '\n' << "    MPFR result: " << mpfr_result.str() << '\n';
-  tlog << "   MPFR rounded: " << str(FPBits<T>(mpfr_result.as<T>())) << '\n';
-  tlog << '\n';
-  tlog << "      ULP error: "
-       << mpfr_result.ulp_as_mpfr_number(matchValue).str() << '\n';
+  cpp::array<char, 1024> msg_buf;
+  cpp::StringStream msg(msg_buf);
+  msg << "Match value not within tolerance value of MPFR result:\n"
+      << "  Input decimal: " << mpfrInput.str() << '\n';
+  msg << "     Input bits: " << str(FPBits<InputType>(input)) << '\n';
+  msg << '\n' << "  Match decimal: " << mpfrMatchValue.str() << '\n';
+  msg << "     Match bits: " << str(FPBits<OutputType>(matchValue)) << '\n';
+  msg << '\n' << "    MPFR result: " << mpfr_result.str() << '\n';
+  msg << "   MPFR rounded: "
+      << str(FPBits<OutputType>(mpfr_result.as<OutputType>())) << '\n';
+  msg << '\n';
+  msg << "      ULP error: " << mpfr_result.ulp_as_mpfr_number(matchValue).str()
+      << '\n';
+  if (msg.overflow())
+    __builtin_unreachable();
+  tlog << msg.str();
 }
 
 template void explain_unary_operation_single_output_error<float>(Operation op,
@@ -790,37 +796,10 @@ template void explain_unary_operation_single_output_error<long double>(
 #ifdef LIBC_TYPES_HAS_FLOAT16
 template void explain_unary_operation_single_output_error<float16>(
     Operation op, float16, float16, double, RoundingMode);
-#endif
-
-template <typename OutType, typename InType>
-void explain_unary_narrower_operation_single_output_error(
-    Operation op, InType input, OutType matchValue, double ulp_tolerance,
-    RoundingMode rounding) {
-  unsigned int precision = get_precision<InType>(ulp_tolerance);
-  MPFRNumber mpfrInput(input, precision);
-  MPFRNumber mpfr_result;
-  mpfr_result = unary_operation(op, input, precision, rounding);
-  MPFRNumber mpfrMatchValue(matchValue);
-  cpp::array<char, 4096> msg_data;
-  cpp::StringStream msg(msg_data);
-  msg << "Match value not within tolerance value of MPFR result:\n"
-      << "  Input decimal: " << mpfrInput.str() << '\n';
-  msg << "     Input bits: " << str(FPBits<InType>(input)) << '\n';
-  msg << '\n' << "  Match decimal: " << mpfrMatchValue.str() << '\n';
-  msg << "     Match bits: " << str(FPBits<OutType>(matchValue)) << '\n';
-  msg << '\n' << "    MPFR result: " << mpfr_result.str() << '\n';
-  msg << "   MPFR rounded: " << str(FPBits<OutType>(mpfr_result.as<OutType>()))
-      << '\n';
-  msg << '\n';
-  msg << "      ULP error: " << mpfr_result.ulp_as_mpfr_number(matchValue).str()
-      << '\n';
-  tlog << msg.str();
-}
-
-#ifdef LIBC_TYPES_HAS_FLOAT16
-template void
-explain_unary_narrower_operation_single_output_error<float16, float>(
-    Operation op, float, float16, double, RoundingMode);
+template void explain_unary_operation_single_output_error<float>(Operation op,
+                                                                 float, float16,
+                                                                 double,
+                                                                 RoundingMode);
 #endif
 
 template <typename T>
@@ -982,11 +961,12 @@ template void explain_ternary_operation_one_output_error<long double>(
     Operation, const TernaryInput<long double> &, long double, double,
     RoundingMode);
 
-template <typename T>
-bool compare_unary_operation_single_output(Operation op, T input, T libc_result,
+template <typename InputType, typename OutputType>
+bool compare_unary_operation_single_output(Operation op, InputType input,
+                                           OutputType libc_result,
                                            double ulp_tolerance,
                                            RoundingMode rounding) {
-  unsigned int precision = get_precision<T>(ulp_tolerance);
+  unsigned int precision = get_precision<InputType>(ulp_tolerance);
   MPFRNumber mpfr_result;
   mpfr_result = unary_operation(op, input, precision, rounding);
   double ulp = mpfr_result.ulp(libc_result);
@@ -1005,22 +985,9 @@ template bool compare_unary_operation_single_output<long double>(
 template bool compare_unary_operation_single_output<float16>(Operation, float16,
                                                              float16, double,
                                                              RoundingMode);
-#endif
-
-template <typename OutType, typename InType>
-bool compare_unary_narrower_operation_single_output(Operation op, InType input,
-                                                    OutType libc_result,
-                                                    double ulp_tolerance,
-                                                    RoundingMode rounding) {
-  unsigned int precision = get_precision<InType>(ulp_tolerance);
-  MPFRNumber mpfr_result;
-  mpfr_result = unary_operation(op, input, precision, rounding);
-  double ulp = mpfr_result.ulp(libc_result);
-  return (ulp <= ulp_tolerance);
-}
-#ifdef LIBC_TYPES_HAS_FLOAT16
-template bool compare_unary_narrower_operation_single_output<float16, float>(
-    Operation, float, float16, double, RoundingMode);
+template bool compare_unary_operation_single_output<float>(Operation, float,
+                                                           float16, double,
+                                                           RoundingMode);
 #endif
 
 template <typename T>
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index adccb19ea091c..32488343b9d2e 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -129,15 +129,11 @@ struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
   static constexpr bool VALUE = cpp::is_floating_point_v<T>;
 };
 
-template <typename T>
-bool compare_unary_operation_single_output(Operation op, T input, T libc_output,
+template <typename InputType, typename OutputType>
+bool compare_unary_operation_single_output(Operation op, InputType input,
+                                           OutputType libc_output,
                                            double ulp_tolerance,
                                            RoundingMode rounding);
-template <typename OutType, typename InType>
-bool compare_unary_narrower_operation_single_output(Operation op, InType input,
-                                                    OutType libc_output,
-                                                    double ulp_tolerance,
-                                                    RoundingMode rounding);
 template <typename T>
 bool compare_unary_operation_two_outputs(Operation op, T input,
                                          const BinaryOutput<T> &libc_output,
@@ -162,15 +158,11 @@ bool compare_ternary_operation_one_output(Operation op,
                                           T libc_output, double ulp_tolerance,
                                           RoundingMode rounding);
 
-template <typename T>
-void explain_unary_operation_single_output_error(Operation op, T input,
-                                                 T match_value,
+template <typename InputType, typename OutputType>
+void explain_unary_operation_single_output_error(Operation op, InputType input,
+                                                 OutputType match_value,
                                                  double ulp_tolerance,
                                                  RoundingMode rounding);
-template <typename OutType, typename InType>
-void explain_unary_narrower_operation_single_output_error(
-    Operation op, InType input, OutType match_value, double ulp_tolerance,
-    RoundingMode rounding);
 template <typename T>
 void explain_unary_operation_two_outputs_error(
     Operation op, T input, const BinaryOutput<T> &match_value,
@@ -214,23 +206,18 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
   // This method is marked with NOLINT because the name `explainError` does not
   // conform to the coding style.
   void explainError() override { // NOLINT
-    explain_error(input, match_value);
+    explain_error<InputType, OutputType>(input, match_value);
   }
 
   // Whether the `explainError` step is skipped or not.
   bool is_silent() const override { return silent; }
 
 private:
-  template <typename T> bool match(T in, T out) {
+  template <typename T, typename U> bool match(T in, U out) {
     return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
                                                  rounding);
   }
 
-  template <typename T, typename U> bool match(T in, U out) {
-    return compare_unary_narrower_operation_single_output(
-        op, in, out, ulp_tolerance, rounding);
-  }
-
   template <typename T> bool match(T in, const BinaryOutput<T> &out) {
     return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
                                                rounding);
@@ -252,16 +239,12 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                                 rounding);
   }
 
-  template <typename T> void explain_error(T in, T out) {
+  template <typename T, typename U>
+  void explain_error(InputType in, OutputType out) {
     explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
                                                 rounding);
   }
 
-  template <typename T, typename U> void explain_error(T in, U out) {
-    explain_unary_narrower_operation_single_output_error(
-        op, in, out, ulp_tolerance, rounding);
-  }
-
   template <typename T> void explain_error(T in, const BinaryOutput<T> &out) {
     explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
                                               rounding);
@@ -290,8 +273,13 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
 // types.
 template <Operation op, typename InputType, typename OutputType>
 constexpr bool is_valid_operation() {
-  return (op == Operation::Sqrt) ||
-         (Operation::BeginUnaryOperationsSingleOutput < op &&
+  constexpr bool IS_NARROWING_OP = op == Operation::Sqrt &&
+                                   cpp::is_floating_point_v<InputType> &&
+                                   cpp::is_floating_point_v<OutputType> &&
+                                   sizeof(OutputType) <= sizeof(InputType);
+  if (IS_NARROWING_OP)
+    return true;
+  return (Operation::BeginUnaryOperationsSingleOutput < op &&
           op < Operation::EndUnaryOperationsSingleOutput &&
           cpp::is_same_v<InputType, OutputType> &&
           cpp::is_floating_point_v<InputType>) ||

>From d373c141a9a7685c5b5d2f5fa67f6f404f788eb3 Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Wed, 12 Jun 2024 17:35:53 +0200
Subject: [PATCH 4/8] [libc][math][c23] Remove unnecessary special case
 handling from generic sqrt

---
 libc/src/__support/FPUtil/generic/sqrt.h | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
index 17446a12c298b..d6e894fdfe021 100644
--- a/libc/src/__support/FPUtil/generic/sqrt.h
+++ b/libc/src/__support/FPUtil/generic/sqrt.h
@@ -181,20 +181,6 @@ sqrt(InType x) {
           }
         }
 
-        if (x_exp == OutFPBits::MAX_BIASED_EXPONENT - 1 &&
-            y == OutFPBits::max_normal().uintval() && (rb || sticky)) {
-          switch (quick_get_round()) {
-          case FE_TONEAREST:
-            if (rb)
-              return OutFPBits::inf().get_val();
-            return OutFPBits::max_normal().get_val();
-          case FE_UPWARD:
-            return OutFPBits::inf().get_val();
-          default:
-            return OutFPBits::max_normal().get_val();
-          }
-        }
-
         if (x_exp <
             -OutFPBits::EXP_BIAS - OutFPBits::SIG_LEN + EXTRA_FRACTION_LEN) {
           switch (quick_get_round()) {

>From cd9092e84936fe496863c617bdebec6468004599 Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Wed, 12 Jun 2024 20:40:18 +0200
Subject: [PATCH 5/8] [libc][math] Merge UnaryNarrowingOpChecker into
 UnaryOpChecker

---
 .../src/math/exhaustive/exhaustive_test.h     | 55 ++++---------------
 1 file changed, 10 insertions(+), 45 deletions(-)

diff --git a/libc/test/src/math/exhaustive/exhaustive_test.h b/libc/test/src/math/exhaustive/exhaustive_test.h
index ec6e9e79e36e8..13e272783250b 100644
--- a/libc/test/src/math/exhaustive/exhaustive_test.h
+++ b/libc/test/src/math/exhaustive/exhaustive_test.h
@@ -35,51 +35,16 @@
 //   LlvmLibcUnaryOpExhaustiveMathTest<FloatType, Op, Func>.
 namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
 
-template <typename T> using UnaryOp = T(T);
-
-template <typename T, mpfr::Operation Op, UnaryOp<T> Func>
-struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
-  using FloatType = T;
-  using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
-  using StorageType = typename FPBits::StorageType;
-
-  static constexpr UnaryOp<FloatType> *FUNC = Func;
-
-  // Check in a range, return the number of failures.
-  uint64_t check(StorageType start, StorageType stop,
-                 mpfr::RoundingMode rounding) {
-    mpfr::ForceRoundingMode r(rounding);
-    if (!r.success)
-      return (stop > start);
-    StorageType bits = start;
-    uint64_t failed = 0;
-    do {
-      FPBits xbits(bits);
-      FloatType x = xbits.get_val();
-      bool correct =
-          TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, FUNC(x), 0.5, rounding);
-      failed += (!correct);
-      // Uncomment to print out failed values.
-      // if (!correct) {
-      //   TEST_MPFR_MATCH(Op::Operation, x, Op::func(x), 0.5, rounding);
-      // }
-    } while (bits++ < stop);
-    return failed;
-  }
-};
-
-template <typename OutType, typename InType>
-using UnaryNarrowingOp = OutType(InType);
+template <typename OutType, typename InType = OutType>
+using UnaryOp = OutType(InType);
 
 template <typename OutType, typename InType, mpfr::Operation Op,
-          UnaryNarrowingOp<OutType, InType> Func>
-struct UnaryNarrowingOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
+          UnaryOp<OutType, InType> Func>
+struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
   using FloatType = InType;
   using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
   using StorageType = typename FPBits::StorageType;
 
-  static constexpr UnaryNarrowingOp<OutType, FloatType> *FUNC = Func;
-
   // Check in a range, return the number of failures.
   uint64_t check(StorageType start, StorageType stop,
                  mpfr::RoundingMode rounding) {
@@ -92,11 +57,11 @@ struct UnaryNarrowingOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
       FPBits xbits(bits);
       FloatType x = xbits.get_val();
       bool correct =
-          TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, FUNC(x), 0.5, rounding);
+          TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, Func(x), 0.5, rounding);
       failed += (!correct);
       // Uncomment to print out failed values.
       // if (!correct) {
-      //   EXPECT_MPFR_MATCH_ROUNDING(Op, x, FUNC(x), 0.5, rounding);
+      //   EXPECT_MPFR_MATCH_ROUNDING(Op, x, Func(x), 0.5, rounding);
       // }
     } while (bits++ < stop);
     return failed;
@@ -204,9 +169,9 @@ struct LlvmLibcExhaustiveMathTest
 
 template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func>
 using LlvmLibcUnaryOpExhaustiveMathTest =
-    LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, Op, Func>>;
+    LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, FloatType, Op, Func>>;
 
 template <typename OutType, typename InType, mpfr::Operation Op,
-          UnaryNarrowingOp<OutType, InType> Func>
-using LlvmLibcUnaryNarrowingOpExhaustiveMathTest = LlvmLibcExhaustiveMathTest<
-    UnaryNarrowingOpChecker<OutType, InType, Op, Func>>;
+          UnaryOp<OutType, InType> Func>
+using LlvmLibcUnaryNarrowingOpExhaustiveMathTest =
+    LlvmLibcExhaustiveMathTest<UnaryOpChecker<OutType, InType, Op, Func>>;

>From e0552647f781a9b3c2fcd5348b1f6531680e08ab Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Wed, 12 Jun 2024 21:04:01 +0200
Subject: [PATCH 6/8] [libc][math] Fix build breakage of other exhaustive tests

---
 libc/src/math/generic/acosf.cpp    | 2 +-
 libc/src/math/generic/acoshf.cpp   | 4 ++--
 libc/src/math/generic/asinf.cpp    | 2 +-
 libc/src/math/generic/asinhf.cpp   | 6 +++---
 libc/src/math/generic/hypotf.cpp   | 2 +-
 libc/utils/MPFRWrapper/MPFRUtils.h | 5 ++---
 6 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/libc/src/math/generic/acosf.cpp b/libc/src/math/generic/acosf.cpp
index e6e28d43ef61f..f02edec267174 100644
--- a/libc/src/math/generic/acosf.cpp
+++ b/libc/src/math/generic/acosf.cpp
@@ -113,7 +113,7 @@ LLVM_LIBC_FUNCTION(float, acosf, (float x)) {
   xbits.set_sign(Sign::POS);
   double xd = static_cast<double>(xbits.get_val());
   double u = fputil::multiply_add(-0.5, xd, 0.5);
-  double cv = 2 * fputil::sqrt(u);
+  double cv = 2 * fputil::sqrt<double>(u);
 
   double r3 = asin_eval(u);
   double r = fputil::multiply_add(cv * u, r3, cv);
diff --git a/libc/src/math/generic/acoshf.cpp b/libc/src/math/generic/acoshf.cpp
index a4a75a7b04385..9422ec63e1ce2 100644
--- a/libc/src/math/generic/acoshf.cpp
+++ b/libc/src/math/generic/acoshf.cpp
@@ -66,8 +66,8 @@ LLVM_LIBC_FUNCTION(float, acoshf, (float x)) {
 
   double x_d = static_cast<double>(x);
   // acosh(x) = log(x + sqrt(x^2 - 1))
-  return static_cast<float>(
-      log_eval(x_d + fputil::sqrt(fputil::multiply_add(x_d, x_d, -1.0))));
+  return static_cast<float>(log_eval(
+      x_d + fputil::sqrt<double>(fputil::multiply_add(x_d, x_d, -1.0))));
 }
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/asinf.cpp b/libc/src/math/generic/asinf.cpp
index d9133333d2561..c4afca493a713 100644
--- a/libc/src/math/generic/asinf.cpp
+++ b/libc/src/math/generic/asinf.cpp
@@ -144,7 +144,7 @@ LLVM_LIBC_FUNCTION(float, asinf, (float x)) {
   double sign = SIGN[x_sign];
   double xd = static_cast<double>(xbits.get_val());
   double u = fputil::multiply_add(-0.5, xd, 0.5);
-  double c1 = sign * (-2 * fputil::sqrt(u));
+  double c1 = sign * (-2 * fputil::sqrt<double>(u));
   double c2 = fputil::multiply_add(sign, M_MATH_PI_2, c1);
   double c3 = c1 * u;
 
diff --git a/libc/src/math/generic/asinhf.cpp b/libc/src/math/generic/asinhf.cpp
index 6e351786e3eca..82dc2a31ebc22 100644
--- a/libc/src/math/generic/asinhf.cpp
+++ b/libc/src/math/generic/asinhf.cpp
@@ -97,9 +97,9 @@ LLVM_LIBC_FUNCTION(float, asinhf, (float x)) {
 
   // asinh(x) = log(x + sqrt(x^2 + 1))
   return static_cast<float>(
-      x_sign *
-      log_eval(fputil::multiply_add(
-          x_d, x_sign, fputil::sqrt(fputil::multiply_add(x_d, x_d, 1.0)))));
+      x_sign * log_eval(fputil::multiply_add(
+                   x_d, x_sign,
+                   fputil::sqrt<double>(fputil::multiply_add(x_d, x_d, 1.0)))));
 }
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/hypotf.cpp b/libc/src/math/generic/hypotf.cpp
index ffbf706aefaf6..b09d09ad7f9c9 100644
--- a/libc/src/math/generic/hypotf.cpp
+++ b/libc/src/math/generic/hypotf.cpp
@@ -42,7 +42,7 @@ LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
   double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq;
 
   // Take sqrt in double precision.
-  DoubleBits result(fputil::sqrt(sum_sq));
+  DoubleBits result(fputil::sqrt<double>(sum_sq));
 
   if (!DoubleBits(sum_sq).is_inf_or_nan()) {
     // Correct rounding.
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index 32488343b9d2e..805678b96c2ef 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -206,7 +206,7 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
   // This method is marked with NOLINT because the name `explainError` does not
   // conform to the coding style.
   void explainError() override { // NOLINT
-    explain_error<InputType, OutputType>(input, match_value);
+    explain_error(input, match_value);
   }
 
   // Whether the `explainError` step is skipped or not.
@@ -239,8 +239,7 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                                 rounding);
   }
 
-  template <typename T, typename U>
-  void explain_error(InputType in, OutputType out) {
+  template <typename T, typename U> void explain_error(T in, U out) {
     explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
                                                 rounding);
   }

>From 42f970b25f5efd27ad3dcfc6e7fb6a8fb758b183 Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Wed, 12 Jun 2024 21:15:19 +0200
Subject: [PATCH 7/8] [libc][math] Fix build breakage of other smoke tests

---
 libc/src/math/generic/powf.cpp             | 2 +-
 libc/test/src/math/smoke/SqrtTest.h        | 6 +++++-
 libc/test/src/math/smoke/f16sqrtf_test.cpp | 2 +-
 libc/test/src/math/smoke/sqrt_test.cpp     | 2 +-
 libc/test/src/math/smoke/sqrtf128_test.cpp | 2 +-
 libc/test/src/math/smoke/sqrtf_test.cpp    | 2 +-
 libc/test/src/math/smoke/sqrtl_test.cpp    | 2 +-
 7 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/libc/src/math/generic/powf.cpp b/libc/src/math/generic/powf.cpp
index 59efc3f424c76..13c04240f59c2 100644
--- a/libc/src/math/generic/powf.cpp
+++ b/libc/src/math/generic/powf.cpp
@@ -562,7 +562,7 @@ LLVM_LIBC_FUNCTION(float, powf, (float x, float y)) {
       switch (y_u) {
       case 0x3f00'0000: // y = 0.5f
         // pow(x, 1/2) = sqrt(x)
-        return fputil::sqrt(x);
+        return fputil::sqrt<float>(x);
       case 0x3f80'0000: // y = 1.0f
         return x;
       case 0x4000'0000: // y = 2.0f
diff --git a/libc/test/src/math/smoke/SqrtTest.h b/libc/test/src/math/smoke/SqrtTest.h
index 7731518308fef..ce9f2f85b4604 100644
--- a/libc/test/src/math/smoke/SqrtTest.h
+++ b/libc/test/src/math/smoke/SqrtTest.h
@@ -31,6 +31,10 @@ class SqrtTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
   }
 };
 
-#define LIST_SQRT_TESTS(OutType, InType, func)                                 \
+#define LIST_SQRT_TESTS(T, func)                                               \
+  using LlvmLibcSqrtTest = SqrtTest<T, T>;                                     \
+  TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); }
+
+#define LIST_NARROWING_SQRT_TESTS(OutType, InType, func)                       \
   using LlvmLibcSqrtTest = SqrtTest<OutType, InType>;                          \
   TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); }
diff --git a/libc/test/src/math/smoke/f16sqrtf_test.cpp b/libc/test/src/math/smoke/f16sqrtf_test.cpp
index bf160ccd35c32..36231aeb4184d 100644
--- a/libc/test/src/math/smoke/f16sqrtf_test.cpp
+++ b/libc/test/src/math/smoke/f16sqrtf_test.cpp
@@ -10,4 +10,4 @@
 
 #include "src/math/f16sqrtf.h"
 
-LIST_SQRT_TESTS(float16, float, LIBC_NAMESPACE::f16sqrtf)
+LIST_NARROWING_SQRT_TESTS(float16, float, LIBC_NAMESPACE::f16sqrtf)
diff --git a/libc/test/src/math/smoke/sqrt_test.cpp b/libc/test/src/math/smoke/sqrt_test.cpp
index 8e2e25dee5440..1551b31d6f715 100644
--- a/libc/test/src/math/smoke/sqrt_test.cpp
+++ b/libc/test/src/math/smoke/sqrt_test.cpp
@@ -10,4 +10,4 @@
 
 #include "src/math/sqrt.h"
 
-LIST_SQRT_TESTS(double, double, LIBC_NAMESPACE::sqrt)
+LIST_SQRT_TESTS(double, LIBC_NAMESPACE::sqrt)
diff --git a/libc/test/src/math/smoke/sqrtf128_test.cpp b/libc/test/src/math/smoke/sqrtf128_test.cpp
index 599e8af80ff00..23397b0623ce5 100644
--- a/libc/test/src/math/smoke/sqrtf128_test.cpp
+++ b/libc/test/src/math/smoke/sqrtf128_test.cpp
@@ -10,4 +10,4 @@
 
 #include "src/math/sqrtf128.h"
 
-LIST_SQRT_TESTS(float128, float128, LIBC_NAMESPACE::sqrtf128)
+LIST_SQRT_TESTS(float128, LIBC_NAMESPACE::sqrtf128)
diff --git a/libc/test/src/math/smoke/sqrtf_test.cpp b/libc/test/src/math/smoke/sqrtf_test.cpp
index 13093efd24ae5..3f2e973325bd0 100644
--- a/libc/test/src/math/smoke/sqrtf_test.cpp
+++ b/libc/test/src/math/smoke/sqrtf_test.cpp
@@ -10,4 +10,4 @@
 
 #include "src/math/sqrtf.h"
 
-LIST_SQRT_TESTS(float, float, LIBC_NAMESPACE::sqrtf)
+LIST_SQRT_TESTS(float, LIBC_NAMESPACE::sqrtf)
diff --git a/libc/test/src/math/smoke/sqrtl_test.cpp b/libc/test/src/math/smoke/sqrtl_test.cpp
index f49daf0f90edb..f80bcfb736078 100644
--- a/libc/test/src/math/smoke/sqrtl_test.cpp
+++ b/libc/test/src/math/smoke/sqrtl_test.cpp
@@ -10,4 +10,4 @@
 
 #include "src/math/sqrtl.h"
 
-LIST_SQRT_TESTS(long double, long double, LIBC_NAMESPACE::sqrtl)
+LIST_SQRT_TESTS(long double, LIBC_NAMESPACE::sqrtl)

>From 1510d2f8f7f9d00c2a76036da8c3fddd2e9bd647 Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Thu, 13 Jun 2024 14:52:41 +0200
Subject: [PATCH 8/8] [libc][docs] Sort f16sqrt before fsqrt in table

---
 libc/docs/math/index.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/libc/docs/math/index.rst b/libc/docs/math/index.rst
index 8243b14ff4786..790786147c164 100644
--- a/libc/docs/math/index.rst
+++ b/libc/docs/math/index.rst
@@ -280,10 +280,10 @@ Higher Math Functions
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | fma       | |check|          | |check|         |                        |                      |                        | 7.12.13.1              | F.10.10.1                  |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
-| fsqrt     | N/A              |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
-+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | f16sqrt   | |check|          |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
+| fsqrt     | N/A              |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
++-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | hypot     | |check|          | |check|         |                        |                      |                        | 7.12.7.4               | F.10.4.4                   |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | lgamma    |                  |                 |                        |                      |                        | 7.12.8.3               | F.10.5.3                   |



More information about the libc-commits mailing list