[libc-commits] [libc] 4beba3a - [libc] Revert "Refactor sqrt implementations and add tests for generic sqrt implementations."

Siva Chandra Reddy via libc-commits libc-commits at lists.llvm.org
Thu Jan 27 13:11:29 PST 2022


Author: Siva Chandra Reddy
Date: 2022-01-27T21:06:14Z
New Revision: 4beba3a32a6537b80e88ea7c5e4f5a425599ca5d

URL: https://github.com/llvm/llvm-project/commit/4beba3a32a6537b80e88ea7c5e4f5a425599ca5d
DIFF: https://github.com/llvm/llvm-project/commit/4beba3a32a6537b80e88ea7c5e4f5a425599ca5d.diff

LOG: [libc] Revert "Refactor sqrt implementations and add tests for generic sqrt implementations."

This reverts commit 21c4c82c2026bac1f53be54923c0663d41d0a0aa.

Added: 
    libc/src/__support/FPUtil/Sqrt.h
    libc/src/__support/FPUtil/x86_64/SqrtLongDouble.h
    libc/src/math/x86_64/sqrt.cpp
    libc/src/math/x86_64/sqrtf.cpp
    libc/src/math/x86_64/sqrtl.cpp

Modified: 
    libc/src/__support/FPUtil/CMakeLists.txt
    libc/src/math/aarch64/CMakeLists.txt
    libc/src/math/generic/CMakeLists.txt
    libc/src/math/generic/sqrt.cpp
    libc/src/math/generic/sqrtf.cpp
    libc/src/math/generic/sqrtl.cpp
    libc/src/math/x86_64/CMakeLists.txt
    libc/test/src/math/CMakeLists.txt
    utils/bazel/llvm-project-overlay/libc/BUILD.bazel

Removed: 
    libc/src/__support/FPUtil/aarch64/sqrt.h
    libc/src/__support/FPUtil/generic/CMakeLists.txt
    libc/src/__support/FPUtil/generic/sqrt.h
    libc/src/__support/FPUtil/generic/sqrt_80_bit_long_double.h
    libc/src/__support/FPUtil/sqrt.h
    libc/src/__support/FPUtil/x86_64/sqrt.h
    libc/test/src/math/generic_sqrt_test.cpp
    libc/test/src/math/generic_sqrtf_test.cpp
    libc/test/src/math/generic_sqrtl_test.cpp


################################################################################
diff  --git a/libc/src/__support/FPUtil/CMakeLists.txt b/libc/src/__support/FPUtil/CMakeLists.txt
index d02cd9fcce0ec..6d005a9166c25 100644
--- a/libc/src/__support/FPUtil/CMakeLists.txt
+++ b/libc/src/__support/FPUtil/CMakeLists.txt
@@ -22,14 +22,3 @@ add_header_library(
     libc.src.__support.common
     libc.src.__support.CPP.standalone_cpp
 )
-
-add_header_library(
-  sqrt
-  HDRS
-    sqrt.h
-  DEPENDS
-    .fputil
-    libc.src.__support.FPUtil.generic.sqrt
-)
-
-add_subdirectory(generic)

diff  --git a/libc/src/__support/FPUtil/Sqrt.h b/libc/src/__support/FPUtil/Sqrt.h
new file mode 100644
index 0000000000000..652883ffc96b9
--- /dev/null
+++ b/libc/src/__support/FPUtil/Sqrt.h
@@ -0,0 +1,192 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
+
+#include "FPBits.h"
+#include "PlatformDefs.h"
+
+#include "src/__support/CPP/TypeTraits.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+namespace internal {
+
+template <typename T>
+static inline void normalize(int &exponent,
+                             typename FPBits<T>::UIntType &mantissa);
+
+template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
+  // Use binary search to shift the leading 1 bit.
+  // With MantissaWidth<float> = 23, it will take
+  // ceil(log2(23)) = 5 steps, checking the mantissa bits as follows:
+  // Step 1: 0000 0000 0000 XXXX XXXX XXXX
+  // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
+  // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
+  // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
+  // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
+  constexpr int NSTEPS = 5; // = ceil(log2(MantissaWidth))
+  constexpr uint32_t BOUNDS[NSTEPS] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
+                                       1 << 23};
+  constexpr int SHIFTS[NSTEPS] = {12, 6, 3, 2, 1};
+
+  for (int i = 0; i < NSTEPS; ++i) {
+    if (mantissa < BOUNDS[i]) {
+      exponent -= SHIFTS[i];
+      mantissa <<= SHIFTS[i];
+    }
+  }
+}
+
+template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
+  // Use binary search to shift the leading 1 bit similar to float.
+  // With MantissaWidth<double> = 52, it will take
+  // ceil(log2(52)) = 6 steps checking the mantissa bits.
+  constexpr int NSTEPS = 6; // = ceil(log2(MantissaWidth))
+  constexpr uint64_t BOUNDS[NSTEPS] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
+                                       1ULL << 49, 1ULL << 51, 1ULL << 52};
+  constexpr int SHIFTS[NSTEPS] = {27, 14, 7, 4, 2, 1};
+
+  for (int i = 0; i < NSTEPS; ++i) {
+    if (mantissa < BOUNDS[i]) {
+      exponent -= SHIFTS[i];
+      mantissa <<= SHIFTS[i];
+    }
+  }
+}
+
+#ifdef LONG_DOUBLE_IS_DOUBLE
+template <>
+inline void normalize<long double>(int &exponent, uint64_t &mantissa) {
+  normalize<double>(exponent, mantissa);
+}
+#elif !defined(SPECIAL_X86_LONG_DOUBLE)
+template <>
+inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
+  // Use binary search to shift the leading 1 bit similar to float.
+  // With MantissaWidth<long double> = 112, it will take
+  // ceil(log2(112)) = 7 steps checking the mantissa bits.
+  constexpr int NSTEPS = 7; // = ceil(log2(MantissaWidth))
+  constexpr __uint128_t BOUNDS[NSTEPS] = {
+      __uint128_t(1) << 56,  __uint128_t(1) << 84,  __uint128_t(1) << 98,
+      __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111,
+      __uint128_t(1) << 112};
+  constexpr int SHIFTS[NSTEPS] = {57, 29, 15, 8, 4, 2, 1};
+
+  for (int i = 0; i < NSTEPS; ++i) {
+    if (mantissa < BOUNDS[i]) {
+      exponent -= SHIFTS[i];
+      mantissa <<= SHIFTS[i];
+    }
+  }
+}
+#endif
+
+} // namespace internal
+
+// Correctly rounded IEEE 754 SQRT; always rounds to nearest, ties to even,
+// ignoring the current rounding mode. Shift-and-add algorithm.
+template <typename T,
+          cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0>
+static inline T sqrt(T x) {
+  using UIntType = typename FPBits<T>::UIntType;
+  constexpr UIntType ONE = UIntType(1) << MantissaWidth<T>::VALUE;
+
+  FPBits<T> bits(x);
+
+  if (bits.is_inf_or_nan()) {
+    if (bits.get_sign() && (bits.get_mantissa() == 0)) {
+      // sqrt(-Inf) = NaN
+      return FPBits<T>::build_nan(ONE >> 1);
+    } else {
+      // sqrt(NaN) = NaN
+      // sqrt(+Inf) = +Inf
+      return x;
+    }
+  } else if (bits.is_zero()) {
+    // sqrt(+0) = +0
+    // sqrt(-0) = -0
+    return x;
+  } else if (bits.get_sign()) {
+    // sqrt( negative numbers ) = NaN
+    return FPBits<T>::build_nan(ONE >> 1);
+  } else {
+    int x_exp = bits.get_exponent();
+    UIntType x_mant = bits.get_mantissa();
+
+    // Step 1a: Normalize denormal input and append hidden bit to the mantissa
+    if (bits.get_unbiased_exponent() == 0) {
+      ++x_exp; // let x_exp be the correct exponent of ONE bit.
+      internal::normalize<T>(x_exp, x_mant);
+    } else {
+      x_mant |= ONE;
+    }
+
+    // Step 1b: Make sure the exponent is even.
+    if (x_exp & 1) {
+      --x_exp;
+      x_mant <<= 1;
+    }
+
+    // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
+    // 1 <= x_mant < 4.  So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
+    // Notice that the output of sqrt is always in the normal range.
+    // To find y with the shift-and-add algorithm, let us denote
+    //   y(n) = 1.y_1 y_2 ... y_n, and define the nth residue to be:
+    //   r(n) = 2^n ( x_mant - y(n)^2 ).
+    // That leads to the following recurrence formula:
+    //   r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n) ]
+    // with the initial conditions: y(0) = 1, and r(0) = x_mant - 1.
+    // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
+    //   y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n)
+    //         0 otherwise.
+    UIntType y = ONE;
+    UIntType r = x_mant - ONE;
+
+    for (UIntType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
+      r <<= 1;
+      UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n)
+      if (r >= tmp) {
+        r -= tmp;
+        y += current_bit;
+      }
+    }
+
+    // We compute one more iteration in order to round correctly.
+    bool lsb = y & 1; // Least significant bit
+    bool rb = false;  // Round bit
+    r <<= 2;
+    UIntType tmp = (y << 2) + 1;
+    if (r >= tmp) {
+      r -= tmp;
+      rb = true;
+    }
+
+    // Remove hidden bit and append the exponent field.
+    x_exp = ((x_exp >> 1) + FPBits<T>::EXPONENT_BIAS);
+
+    y = (y - ONE) | (static_cast<UIntType>(x_exp) << MantissaWidth<T>::VALUE);
+    // Round to nearest, ties to even
+    if (rb && (lsb || (r != 0))) {
+      ++y;
+    }
+
+    return *reinterpret_cast<T *>(&y);
+  }
+}
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#ifdef SPECIAL_X86_LONG_DOUBLE
+#include "x86_64/SqrtLongDouble.h"
+#endif // SPECIAL_X86_LONG_DOUBLE
+
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H

diff  --git a/libc/src/__support/FPUtil/aarch64/sqrt.h b/libc/src/__support/FPUtil/aarch64/sqrt.h
deleted file mode 100644
index 479ebf76b6786..0000000000000
--- a/libc/src/__support/FPUtil/aarch64/sqrt.h
+++ /dev/null
@@ -1,38 +0,0 @@
-//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_AARCH64_SQRT_H
-#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_AARCH64_SQRT_H
-
-#include "src/__support/architectures.h"
-
-#if !defined(LLVM_LIBC_ARCH_AARCH64)
-#error "Invalid include"
-#endif
-
-#include "src/__support/FPUtil/generic/sqrt.h"
-
-namespace __llvm_libc {
-namespace fputil {
-
-template <> inline float sqrt<float>(float x) {
-  float y;
-  __asm__ __volatile__("fsqrt %s0, %s1\n\t" : "=w"(y) : "w"(x));
-  return y;
-}
-
-template <> inline double sqrt<double>(double x) {
-  double y;
-  __asm__ __volatile__("fsqrt %d0, %d1\n\t" : "=w"(y) : "w"(x));
-  return y;
-}
-
-} // namespace fputil
-} // namespace __llvm_libc
-
-#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_AARCH64_SQRT_H

diff  --git a/libc/src/__support/FPUtil/generic/CMakeLists.txt b/libc/src/__support/FPUtil/generic/CMakeLists.txt
deleted file mode 100644
index bf69e7dd961cd..0000000000000
--- a/libc/src/__support/FPUtil/generic/CMakeLists.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-add_header_library(
-  sqrt
-  HDRS
-    sqrt.h
-    sqrt_80_bit_long_double.h
-)

diff  --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
deleted file mode 100644
index 92b18e297ae14..0000000000000
--- a/libc/src/__support/FPUtil/generic/sqrt.h
+++ /dev/null
@@ -1,215 +0,0 @@
-//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
-#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
-
-#include "sqrt_80_bit_long_double.h"
-#include "src/__support/CPP/TypeTraits.h"
-#include "src/__support/FPUtil/FEnvImpl.h"
-#include "src/__support/FPUtil/FPBits.h"
-#include "src/__support/FPUtil/PlatformDefs.h"
-
-namespace __llvm_libc {
-namespace fputil {
-
-namespace internal {
-
-#if defined(SPECIAL_X86_LONG_DOUBLE)
-struct SpecialLongDouble {
-  static constexpr bool VALUE = true;
-};
-#else
-struct SpecialLongDouble {
-  static constexpr bool VALUE = false;
-};
-#endif // SPECIAL_X86_LONG_DOUBLE
-
-template <typename T>
-static inline void normalize(int &exponent,
-                             typename FPBits<T>::UIntType &mantissa);
-
-template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
-  // Use binary search to shift the leading 1 bit.
-  // With MantissaWidth<float> = 23, it will take
-  // ceil(log2(23)) = 5 steps checking the mantissa bits as followed:
-  // Step 1: 0000 0000 0000 XXXX XXXX XXXX
-  // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
-  // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
-  // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
-  // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
-  constexpr int NSTEPS = 5; // = ceil(log2(MantissaWidth))
-  constexpr uint32_t BOUNDS[NSTEPS] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
-                                       1 << 23};
-  constexpr int SHIFTS[NSTEPS] = {12, 6, 3, 2, 1};
-
-  for (int i = 0; i < NSTEPS; ++i) {
-    if (mantissa < BOUNDS[i]) {
-      exponent -= SHIFTS[i];
-      mantissa <<= SHIFTS[i];
-    }
-  }
-}
-
-template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
-  // Use binary search to shift the leading 1 bit similar to float.
-  // With MantissaWidth<double> = 52, it will take
-  // ceil(log2(52)) = 6 steps checking the mantissa bits.
-  constexpr int NSTEPS = 6; // = ceil(log2(MantissaWidth))
-  constexpr uint64_t BOUNDS[NSTEPS] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
-                                       1ULL << 49, 1ULL << 51, 1ULL << 52};
-  constexpr int SHIFTS[NSTEPS] = {27, 14, 7, 4, 2, 1};
-
-  for (int i = 0; i < NSTEPS; ++i) {
-    if (mantissa < BOUNDS[i]) {
-      exponent -= SHIFTS[i];
-      mantissa <<= SHIFTS[i];
-    }
-  }
-}
-
-#ifdef LONG_DOUBLE_IS_DOUBLE
-template <>
-inline void normalize<long double>(int &exponent, uint64_t &mantissa) {
-  normalize<double>(exponent, mantissa);
-}
-#elif !defined(SPECIAL_X86_LONG_DOUBLE)
-template <>
-inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
-  // Use binary search to shift the leading 1 bit similar to float.
-  // With MantissaWidth<long double> = 112, it will take
-  // ceil(log2(112)) = 7 steps checking the mantissa bits.
-  constexpr int NSTEPS = 7; // = ceil(log2(MantissaWidth))
-  constexpr __uint128_t BOUNDS[NSTEPS] = {
-      __uint128_t(1) << 56,  __uint128_t(1) << 84,  __uint128_t(1) << 98,
-      __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111,
-      __uint128_t(1) << 112};
-  constexpr int SHIFTS[NSTEPS] = {57, 29, 15, 8, 4, 2, 1};
-
-  for (int i = 0; i < NSTEPS; ++i) {
-    if (mantissa < BOUNDS[i]) {
-      exponent -= SHIFTS[i];
-      mantissa <<= SHIFTS[i];
-    }
-  }
-}
-#endif
-
-} // namespace internal
-
-// Correctly rounded IEEE 754 SQRT for all rounding modes.
-// Shift-and-add algorithm.
-template <typename T>
-static inline cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, T>
-sqrt(T x) {
-
-  if constexpr (cpp::IsSameV<T, long double> &&
-                internal::SpecialLongDouble::VALUE) {
-    // Special 80-bit long double.
-    return x86::sqrt(x);
-  } else {
-    // IEEE floating points formats.
-    using UIntType = typename FPBits<T>::UIntType;
-    constexpr UIntType ONE = UIntType(1) << MantissaWidth<T>::VALUE;
-
-    FPBits<T> bits(x);
-
-    if (bits.is_inf_or_nan()) {
-      if (bits.get_sign() && (bits.get_mantissa() == 0)) {
-        // sqrt(-Inf) = NaN
-        return FPBits<T>::build_nan(ONE >> 1);
-      } else {
-        // sqrt(NaN) = NaN
-        // sqrt(+Inf) = +Inf
-        return x;
-      }
-    } else if (bits.is_zero()) {
-      // sqrt(+0) = +0
-      // sqrt(-0) = -0
-      return x;
-    } else if (bits.get_sign()) {
-      // sqrt( negative numbers ) = NaN
-      return FPBits<T>::build_nan(ONE >> 1);
-    } else {
-      int x_exp = bits.get_exponent();
-      UIntType x_mant = bits.get_mantissa();
-
-      // Step 1a: Normalize denormal input and append hidden bit to the mantissa
-      if (bits.get_unbiased_exponent() == 0) {
-        ++x_exp; // let x_exp be the correct exponent of ONE bit.
-        internal::normalize<T>(x_exp, x_mant);
-      } else {
-        x_mant |= ONE;
-      }
-
-      // Step 1b: Make sure the exponent is even.
-      if (x_exp & 1) {
-        --x_exp;
-        x_mant <<= 1;
-      }
-
-      // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
-      // 1 <= x_mant < 4.  So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
-      // Notice that the output of sqrt is always in the normal range.
-      // To perform shift-and-add algorithm to find y, let denote:
-      //   y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
-      //   r(n) = 2^n ( x_mant - y(n)^2 ).
-      // That leads to the following recurrence formula:
-      //   r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
-      // with the initial conditions: y(0) = 1, and r(0) = x - 1.
-      // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
-      //   y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
-      //         0 otherwise.
-      UIntType y = ONE;
-      UIntType r = x_mant - ONE;
-
-      for (UIntType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
-        r <<= 1;
-        UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
-        if (r >= tmp) {
-          r -= tmp;
-          y += current_bit;
-        }
-      }
-
-      // We compute one more iteration in order to round correctly.
-      bool lsb = y & 1; // Least significant bit
-      bool rb = false;  // Round bit
-      r <<= 2;
-      UIntType tmp = (y << 2) + 1;
-      if (r >= tmp) {
-        r -= tmp;
-        rb = true;
-      }
-
-      // Remove hidden bit and append the exponent field.
-      x_exp = ((x_exp >> 1) + FPBits<T>::EXPONENT_BIAS);
-
-      y = (y - ONE) | (static_cast<UIntType>(x_exp) << MantissaWidth<T>::VALUE);
-
-      switch (get_round()) {
-      case FE_TONEAREST:
-        // Round to nearest, ties to even
-        if (rb && (lsb || (r != 0)))
-          ++y;
-        break;
-      case FE_UPWARD:
-        if (rb || (r != 0))
-          ++y;
-        break;
-      }
-
-      return *reinterpret_cast<T *>(&y);
-    }
-  }
-}
-
-} // namespace fputil
-} // namespace __llvm_libc
-
-#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H

diff  --git a/libc/src/__support/FPUtil/sqrt.h b/libc/src/__support/FPUtil/sqrt.h
deleted file mode 100644
index 6e02d9c77e8e0..0000000000000
--- a/libc/src/__support/FPUtil/sqrt.h
+++ /dev/null
@@ -1,22 +0,0 @@
-//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
-#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H
-
-#include "src/__support/architectures.h"
-
-#if defined(LLVM_LIBC_ARCH_X86_64)
-#include "x86_64/sqrt.h"
-#elif defined(LLVM_LIBC_ARCH_AARCH64)
-#include "aarch64/sqrt.h"
-#else
-#include "generic/sqrt.h"
-
-#endif
-#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H

diff  --git a/libc/src/__support/FPUtil/generic/sqrt_80_bit_long_double.h b/libc/src/__support/FPUtil/x86_64/SqrtLongDouble.h
similarity index 83%
rename from libc/src/__support/FPUtil/generic/sqrt_80_bit_long_double.h
rename to libc/src/__support/FPUtil/x86_64/SqrtLongDouble.h
index 6c9c6ab748a2e..22d2ba2592c8f 100644
--- a/libc/src/__support/FPUtil/generic/sqrt_80_bit_long_double.h
+++ b/libc/src/__support/FPUtil/x86_64/SqrtLongDouble.h
@@ -6,17 +6,26 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H
-#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H
+#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_LONG_DOUBLE_H
+#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_LONG_DOUBLE_H
 
-#include "src/__support/FPUtil/FEnvImpl.h"
+#include "src/__support/architectures.h"
+
+#if !defined(LLVM_LIBC_ARCH_X86)
+#error "Invalid include"
+#endif
+
+#include "src/__support/CPP/TypeTraits.h"
 #include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/FPUtil/Sqrt.h"
 
 namespace __llvm_libc {
 namespace fputil {
-namespace x86 {
 
-inline void normalize(int &exponent, __uint128_t &mantissa) {
+namespace internal {
+
+template <>
+inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
   // Use binary search to shift the leading 1 bit similar to float.
   // With MantissaWidth<long double> = 63, it will take
   // ceil(log2(63)) = 6 steps checking the mantissa bits.
@@ -34,9 +43,11 @@ inline void normalize(int &exponent, __uint128_t &mantissa) {
   }
 }
 
-// Correctly rounded SQRT for all rounding modes.
+} // namespace internal
+
+// Correctly rounded SQRT with round to nearest, ties to even.
 // Shift-and-add algorithm.
-static inline long double sqrt(long double x) {
+template <> inline long double sqrt<long double, 0>(long double x) {
   using UIntType = typename FPBits<long double>::UIntType;
   constexpr UIntType ONE = UIntType(1)
                            << int(MantissaWidth<long double>::VALUE);
@@ -67,7 +78,7 @@ static inline long double sqrt(long double x) {
     if (bits.get_implicit_bit()) {
       x_mant |= ONE;
     } else if (bits.get_unbiased_exponent() == 0) {
-      normalize(x_exp, x_mant);
+      internal::normalize<long double>(x_exp, x_mant);
     }
 
     // Step 1b: Make sure the exponent is even.
@@ -115,16 +126,9 @@ static inline long double sqrt(long double x) {
     y |= (static_cast<UIntType>(x_exp)
           << (MantissaWidth<long double>::VALUE + 1));
 
-    switch (get_round()) {
-    case FE_TONEAREST:
-      // Round to nearest, ties to even
-      if (rb && (lsb || (r != 0)))
-        ++y;
-      break;
-    case FE_UPWARD:
-      if (rb || (r != 0))
-        ++y;
-      break;
+    // Round to nearest, ties to even
+    if (rb && (lsb || (r != 0))) {
+      ++y;
     }
 
     // Extract output
@@ -137,8 +141,7 @@ static inline long double sqrt(long double x) {
   }
 }
 
-} // namespace x86
 } // namespace fputil
 } // namespace __llvm_libc
 
-#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H
+#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_LONG_DOUBLE_H

diff  --git a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h
deleted file mode 100644
index 8a8f8cf2238db..0000000000000
--- a/libc/src/__support/FPUtil/x86_64/sqrt.h
+++ /dev/null
@@ -1,44 +0,0 @@
-//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_H
-#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_H
-
-#include "src/__support/architectures.h"
-
-#if !defined(LLVM_LIBC_ARCH_X86)
-#error "Invalid include"
-#endif
-
-#include "src/__support/FPUtil/generic/sqrt.h"
-
-namespace __llvm_libc {
-namespace fputil {
-
-template <> inline float sqrt<float>(float x) {
-  float result;
-  __asm__ __volatile__("sqrtss %x1, %x0" : "=x"(result) : "x"(x));
-  return result;
-}
-
-template <> inline double sqrt<double>(double x) {
-  double result;
-  __asm__ __volatile__("sqrtsd %x1, %x0" : "=x"(result) : "x"(x));
-  return result;
-}
-
-template <> inline long double sqrt<long double>(long double x) {
-  long double result;
-  __asm__ __volatile__("fsqrt" : "=t"(result) : "t"(x));
-  return result;
-}
-
-} // namespace fputil
-} // namespace __llvm_libc
-
-#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_SQRT_H

diff  --git a/libc/src/math/aarch64/CMakeLists.txt b/libc/src/math/aarch64/CMakeLists.txt
index bbe927a1c7c88..6ce89441857ca 100644
--- a/libc/src/math/aarch64/CMakeLists.txt
+++ b/libc/src/math/aarch64/CMakeLists.txt
@@ -77,3 +77,23 @@ add_entrypoint_object(
   COMPILE_OPTIONS
     -O2
 )
+
+add_entrypoint_object(
+  sqrt
+  SRCS
+    sqrt.cpp
+  HDRS
+    ../sqrt.h
+  COMPILE_OPTIONS
+    -O2
+)
+
+add_entrypoint_object(
+  sqrtf
+  SRCS
+    sqrtf.cpp
+  HDRS
+    ../sqrtf.h
+  COMPILE_OPTIONS
+    -O2
+)

diff  --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index df2ef34d42ce9..88c29eb4ba0bb 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -859,10 +859,8 @@ add_entrypoint_object(
     ../sqrt.h
   DEPENDS
     libc.src.__support.FPUtil.fputil
-    libc.src.__support.FPUtil.sqrt
   COMPILE_OPTIONS
-    -O3
-    -Wno-c++17-extensions
+    -O2
 )
 
 add_entrypoint_object(
@@ -873,10 +871,8 @@ add_entrypoint_object(
     ../sqrtf.h
   DEPENDS
     libc.src.__support.FPUtil.fputil
-    libc.src.__support.FPUtil.sqrt
   COMPILE_OPTIONS
-    -O3
-    -Wno-c++17-extensions
+    -O2
 )
 
 add_entrypoint_object(
@@ -887,10 +883,8 @@ add_entrypoint_object(
     ../sqrtl.h
   DEPENDS
     libc.src.__support.FPUtil.fputil
-    libc.src.__support.FPUtil.sqrt
   COMPILE_OPTIONS
-    -O3
-    -Wno-c++17-extensions
+    -O2
 )
 
 add_entrypoint_object(

diff  --git a/libc/src/math/generic/sqrt.cpp b/libc/src/math/generic/sqrt.cpp
index de21f329e15ab..bd43a5c6919a4 100644
--- a/libc/src/math/generic/sqrt.cpp
+++ b/libc/src/math/generic/sqrt.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/math/sqrt.h"
-#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/FPUtil/Sqrt.h"
 #include "src/__support/common.h"
 
 namespace __llvm_libc {

diff  --git a/libc/src/math/generic/sqrtf.cpp b/libc/src/math/generic/sqrtf.cpp
index 3ca8d381898bb..bae39dd4b27e6 100644
--- a/libc/src/math/generic/sqrtf.cpp
+++ b/libc/src/math/generic/sqrtf.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/math/sqrtf.h"
-#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/FPUtil/Sqrt.h"
 #include "src/__support/common.h"
 
 namespace __llvm_libc {

diff  --git a/libc/src/math/generic/sqrtl.cpp b/libc/src/math/generic/sqrtl.cpp
index 970646a2e4d1a..efbc98eed8446 100644
--- a/libc/src/math/generic/sqrtl.cpp
+++ b/libc/src/math/generic/sqrtl.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/math/sqrtl.h"
-#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/FPUtil/Sqrt.h"
 #include "src/__support/common.h"
 
 namespace __llvm_libc {

diff  --git a/libc/src/math/x86_64/CMakeLists.txt b/libc/src/math/x86_64/CMakeLists.txt
index cd129e3eefb75..d2a48231b787e 100644
--- a/libc/src/math/x86_64/CMakeLists.txt
+++ b/libc/src/math/x86_64/CMakeLists.txt
@@ -27,3 +27,33 @@ add_entrypoint_object(
   COMPILE_OPTIONS
     -O2
 )
+
+add_entrypoint_object(
+  sqrt
+  SRCS
+    sqrt.cpp
+  HDRS
+    ../sqrt.h
+  COMPILE_OPTIONS
+    -O2
+)
+
+add_entrypoint_object(
+  sqrtf
+  SRCS
+    sqrtf.cpp
+  HDRS
+    ../sqrtf.h
+  COMPILE_OPTIONS
+    -O2
+)
+
+add_entrypoint_object(
+  sqrtl
+  SRCS
+    sqrtl.cpp
+  HDRS
+    ../sqrtl.h
+  COMPILE_OPTIONS
+    -O2
+)

diff  --git a/libc/src/math/x86_64/sqrt.cpp b/libc/src/math/x86_64/sqrt.cpp
new file mode 100644
index 0000000000000..5d4e9424e6030
--- /dev/null
+++ b/libc/src/math/x86_64/sqrt.cpp
@@ -0,0 +1,20 @@
+//===-- Implementation of the sqrt function for x86_64 --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/sqrt.h"
+#include "src/__support/common.h"
+
+namespace __llvm_libc {
+
+LLVM_LIBC_FUNCTION(double, sqrt, (double x)) {
+  double result;
+  __asm__ __volatile__("sqrtsd %x1, %x0" : "=x"(result) : "x"(x));
+  return result;
+}
+
+} // namespace __llvm_libc

diff  --git a/libc/src/math/x86_64/sqrtf.cpp b/libc/src/math/x86_64/sqrtf.cpp
new file mode 100644
index 0000000000000..51d22dff2cbcf
--- /dev/null
+++ b/libc/src/math/x86_64/sqrtf.cpp
@@ -0,0 +1,20 @@
+//===-- Implementation of the sqrtf function for x86_64 -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/sqrtf.h"
+#include "src/__support/common.h"
+
+namespace __llvm_libc {
+
+LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) {
+  float result;
+  __asm__ __volatile__("sqrtss %x1, %x0" : "=x"(result) : "x"(x));
+  return result;
+}
+
+} // namespace __llvm_libc

diff  --git a/libc/src/math/x86_64/sqrtl.cpp b/libc/src/math/x86_64/sqrtl.cpp
new file mode 100644
index 0000000000000..8b0c39e95fdd8
--- /dev/null
+++ b/libc/src/math/x86_64/sqrtl.cpp
@@ -0,0 +1,20 @@
+//===-- Implementation of the sqrtl function for x86_64 -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/sqrtl.h"
+#include "src/__support/common.h"
+
+namespace __llvm_libc {
+
+LLVM_LIBC_FUNCTION(long double, sqrtl, (long double x)) {
+  long double result;
+  __asm__ __volatile__("fsqrt" : "=t"(result) : "t"(x));
+  return result;
+}
+
+} // namespace __llvm_libc

diff  --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index 827dc867ff510..73ecef959aba0 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -983,63 +983,26 @@ add_fp_unittest(
     libc.src.__support.FPUtil.fputil
 )
 
-add_fp_unittest(
-  sqrtl_test
-  NEED_MPFR
-  SUITE
-    libc_math_unittests
-  SRCS
-    sqrtl_test.cpp
-  DEPENDS
-    libc.include.math
-    libc.src.math.sqrtl
-    libc.src.__support.FPUtil.fputil
-)
-
-add_fp_unittest(
-  generic_sqrtf_test
-  NEED_MPFR
-  SUITE
-    libc_math_unittests
-  SRCS
-    generic_sqrtf_test.cpp
-  DEPENDS
-    libc.src.__support.FPUtil.fputil
-    libc.src.__support.FPUtil.generic.sqrt
-  COMPILE_OPTIONS
-    -O3
-    -Wno-c++17-extensions
-)
-
-add_fp_unittest(
-  generic_sqrt_test
-  NEED_MPFR
-  SUITE
-    libc_math_unittests
-  SRCS
-    generic_sqrt_test.cpp
-  DEPENDS
-    libc.src.__support.FPUtil.fputil
-    libc.src.__support.FPUtil.generic.sqrt
-  COMPILE_OPTIONS
-    -O3
-    -Wno-c++17-extensions
-)
-
-add_fp_unittest(
-  generic_sqrtl_test
-  NEED_MPFR
-  SUITE
-    libc_math_unittests
-  SRCS
-    generic_sqrtl_test.cpp
-  DEPENDS
-    libc.src.__support.FPUtil.fputil
-    libc.src.__support.FPUtil.generic.sqrt
-  COMPILE_OPTIONS
-    -O3
-    -Wno-c++17-extensions
-)
+# The quad-precision sqrtl test against MPFR currently suffers from
+# insufficient precision in the MPFR calculations (see
+# https://hal.archives-ouvertes.fr/hal-01091186/document). We will
+# re-enable it after fixing the precision issue.
+if(${LIBC_TARGET_ARCHITECTURE_IS_X86})
+  add_fp_unittest(
+    sqrtl_test
+    NEED_MPFR
+    SUITE
+      libc_math_unittests
+    SRCS
+      sqrtl_test.cpp
+    DEPENDS
+      libc.include.math
+      libc.src.math.sqrtl
+      libc.src.__support.FPUtil.fputil
+  )
+else()
+  message(STATUS "Skipping sqrtl_test")
+endif()
 
 add_fp_unittest(
   remquof_test

diff  --git a/libc/test/src/math/generic_sqrt_test.cpp b/libc/test/src/math/generic_sqrt_test.cpp
deleted file mode 100644
index cecfc0ee3de3a..0000000000000
--- a/libc/test/src/math/generic_sqrt_test.cpp
+++ /dev/null
@@ -1,13 +0,0 @@
-//===-- Unittests for generic implementation of sqrt ----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "SqrtTest.h"
-
-#include "src/__support/FPUtil/generic/sqrt.h"
-
-LIST_SQRT_TESTS(double, __llvm_libc::fputil::sqrt<double>)

diff  --git a/libc/test/src/math/generic_sqrtf_test.cpp b/libc/test/src/math/generic_sqrtf_test.cpp
deleted file mode 100644
index 64bf92133b98f..0000000000000
--- a/libc/test/src/math/generic_sqrtf_test.cpp
+++ /dev/null
@@ -1,13 +0,0 @@
-//===-- Unittests for generic implementation of sqrtf----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "SqrtTest.h"
-
-#include "src/__support/FPUtil/generic/sqrt.h"
-
-LIST_SQRT_TESTS(float, __llvm_libc::fputil::sqrt<float>)

diff  --git a/libc/test/src/math/generic_sqrtl_test.cpp b/libc/test/src/math/generic_sqrtl_test.cpp
deleted file mode 100644
index 6b68aaed97004..0000000000000
--- a/libc/test/src/math/generic_sqrtl_test.cpp
+++ /dev/null
@@ -1,13 +0,0 @@
-//===-- Unittests for generic implementation of sqrtl----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "SqrtTest.h"
-
-#include "src/__support/FPUtil/generic/sqrt.h"
-
-LIST_SQRT_TESTS(long double, __llvm_libc::fputil::sqrt<long double>)

diff  --git a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
index 90f216cacb59c..b709565b14830 100644
--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
@@ -80,6 +80,7 @@ fputil_common_hdrs = [
     "src/__support/FPUtil/NearestIntegerOperations.h",
     "src/__support/FPUtil/NormalFloat.h",
     "src/__support/FPUtil/PlatformDefs.h",
+    "src/__support/FPUtil/Sqrt.h",
 ]
 
 fputil_hdrs = selects.with_or({
@@ -87,6 +88,7 @@ fputil_hdrs = selects.with_or({
     PLATFORM_CPU_X86_64: fputil_common_hdrs + [
         "src/__support/FPUtil/x86_64/LongDoubleBits.h",
         "src/__support/FPUtil/x86_64/NextAfterLongDouble.h",
+        "src/__support/FPUtil/x86_64/SqrtLongDouble.h",
         "src/__support/FPUtil/x86_64/FEnvImpl.h",
     ],
     PLATFORM_CPU_ARM64: fputil_common_hdrs + [
@@ -104,31 +106,6 @@ cc_library(
     ],
 )
 
-sqrt_common_hdrs = [
-    "src/__support/FPUtil/sqrt.h",
-    "src/__support/FPUtil/generic/sqrt.h",
-    "src/__support/FPUtil/generic/sqrt_80_bit_long_double.h",
-]
-
-sqrt_hdrs = selects.with_or({
-    "//conditions:default": sqrt_common_hdrs,
-    PLATFORM_CPU_X86_64: sqrt_common_hdrs + [
-        "src/__support/FPUtil/x86_64/sqrt.h",
-    ],
-    PLATFORM_CPU_ARM64: sqrt_common_hdrs + [
-        "src/__support/FPUtil/aarch64/sqrt.h",
-    ],
-})
-
-cc_library(
-    name = "__support_fputil_sqrt",
-    hdrs = sqrt_hdrs,
-    deps = [
-        ":__support_fputil",
-        ":libc_root",
-    ],
-)
-
 ################################ fenv targets ################################
 
 libc_function(
@@ -461,23 +438,28 @@ libc_math_function(
 
 libc_math_function(
     name = "sqrt",
-    additional_deps = [
-        ":__support_fputil_sqrt",
-    ]
+    specializations = [
+        "aarch64",
+        "generic",
+        "x86_64",
+    ],
 )
 
 libc_math_function(
     name = "sqrtf",
-    additional_deps = [
-        ":__support_fputil_sqrt",
-    ]
+    specializations = [
+        "aarch64",
+        "generic",
+        "x86_64",
+    ],
 )
 
 libc_math_function(
     name = "sqrtl",
-    additional_deps = [
-        ":__support_fputil_sqrt",
-    ]
+    specializations = [
+        "generic",
+        "x86_64",
+    ],
 )
 
 libc_math_function(name = "copysign")


        


More information about the libc-commits mailing list