[libc-commits] [libc] 5078825 - [libc] Add implementations for sqrt, sqrtf, and sqrtl.

Tue Ly via libc-commits libc-commits at lists.llvm.org
Wed Aug 26 06:46:49 PDT 2020


Author: Tue Ly
Date: 2020-08-26T09:46:18-04:00
New Revision: 5078825aa982905088502f14b5387fc5c96017fe

URL: https://github.com/llvm/llvm-project/commit/5078825aa982905088502f14b5387fc5c96017fe
DIFF: https://github.com/llvm/llvm-project/commit/5078825aa982905088502f14b5387fc5c96017fe.diff

LOG: [libc] Add implementations for sqrt, sqrtf, and sqrtl.

Differential Revision: https://reviews.llvm.org/D84726

Added: 
    libc/src/math/sqrt.cpp
    libc/src/math/sqrt.h
    libc/src/math/sqrtf.cpp
    libc/src/math/sqrtf.h
    libc/src/math/sqrtl.cpp
    libc/src/math/sqrtl.h
    libc/test/src/math/sqrt_test.cpp
    libc/test/src/math/sqrtf_test.cpp
    libc/test/src/math/sqrtl_test.cpp
    libc/utils/FPUtil/Sqrt.h
    libc/utils/FPUtil/SqrtLongDoubleX86.h

Modified: 
    libc/config/linux/aarch64/entrypoints.txt
    libc/config/linux/api.td
    libc/config/linux/x86_64/entrypoints.txt
    libc/spec/stdc.td
    libc/src/math/CMakeLists.txt
    libc/test/src/math/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index fe63403ae221..34d07c24505d 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -75,6 +75,9 @@ set(TARGET_LIBM_ENTRYPOINTS
     libc.src.math.roundl
     libc.src.math.sincosf
     libc.src.math.sinf
+    libc.src.math.sqrt
+    libc.src.math.sqrtf
+    libc.src.math.sqrtl
     libc.src.math.trunc
     libc.src.math.truncf
     libc.src.math.truncl

diff  --git a/libc/config/linux/api.td b/libc/config/linux/api.td
index 6b50c4284ae2..063fe401da8b 100644
--- a/libc/config/linux/api.td
+++ b/libc/config/linux/api.td
@@ -204,6 +204,9 @@ def MathAPI : PublicAPI<"math.h"> {
    "roundl",
    "sincosf",
    "sinf",
+   "sqrt",
+   "sqrtf",
+   "sqrtl",
    "trunc",
    "truncf",
    "truncl",

diff  --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index 35ca8bbeaf68..c24173b1d0e7 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -108,6 +108,9 @@ set(TARGET_LIBM_ENTRYPOINTS
     libc.src.math.roundl
     libc.src.math.sincosf
     libc.src.math.sinf
+    libc.src.math.sqrt
+    libc.src.math.sqrtf
+    libc.src.math.sqrtl
     libc.src.math.trunc
     libc.src.math.truncf
     libc.src.math.truncl

diff  --git a/libc/spec/stdc.td b/libc/spec/stdc.td
index ac240ff9576e..15fc12d375e6 100644
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -314,6 +314,10 @@ def StdC : StandardSpec<"stdc"> {
           FunctionSpec<"roundf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,
           FunctionSpec<"roundl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>]>,
 
+          FunctionSpec<"sqrt", RetValSpec<DoubleType>, [ArgSpec<DoubleType>]>,
+          FunctionSpec<"sqrtf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,
+          FunctionSpec<"sqrtl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>]>,
+
           FunctionSpec<"trunc", RetValSpec<DoubleType>, [ArgSpec<DoubleType>]>,
           FunctionSpec<"truncf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,
           FunctionSpec<"truncl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>]>,

diff  --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index da18aeba9a2a..0c878de2ac95 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -485,3 +485,39 @@ add_entrypoint_object(
   COMPILE_OPTIONS
     -O2
 )
+
+# Square root entrypoints (sqrt, sqrtf, sqrtl). Each .cpp is a thin wrapper
+# around the header-only fputil::sqrt template from utils/FPUtil/Sqrt.h, which
+# is why the only dependency is libc.utils.FPUtil.fputil.
+add_entrypoint_object(
+  sqrt
+  SRCS
+    sqrt.cpp
+  HDRS
+    sqrt.h
+  DEPENDS
+    libc.utils.FPUtil.fputil
+  COMPILE_OPTIONS
+    -O2
+)
+
+add_entrypoint_object(
+  sqrtf
+  SRCS
+    sqrtf.cpp
+  HDRS
+    sqrtf.h
+  DEPENDS
+    libc.utils.FPUtil.fputil
+  COMPILE_OPTIONS
+    -O2
+)
+
+# On x86 targets sqrtl also picks up utils/FPUtil/SqrtLongDoubleX86.h, which
+# Sqrt.h includes itself; no extra dependency is listed for it here.
+add_entrypoint_object(
+  sqrtl
+  SRCS
+    sqrtl.cpp
+  HDRS
+    sqrtl.h
+  DEPENDS
+    libc.utils.FPUtil.fputil
+  COMPILE_OPTIONS
+    -O2
+)

diff  --git a/libc/src/math/sqrt.cpp b/libc/src/math/sqrt.cpp
new file mode 100644
index 000000000000..32d38e61463d
--- /dev/null
+++ b/libc/src/math/sqrt.cpp
@@ -0,0 +1,16 @@
+//===-- Implementation of sqrt function -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "utils/FPUtil/Sqrt.h"
+#include "src/__support/common.h"
+
+namespace __llvm_libc {
+
+// C `sqrt` entrypoint. All of the work happens in the header-only
+// fputil::sqrt template, which rounds to nearest with ties to even.
+double LLVM_LIBC_ENTRYPOINT(sqrt)(double x) {
+  const double root = fputil::sqrt(x);
+  return root;
+}
+
+} // namespace __llvm_libc

diff  --git a/libc/src/math/sqrt.h b/libc/src/math/sqrt.h
new file mode 100644
index 000000000000..2390e07b5dce
--- /dev/null
+++ b/libc/src/math/sqrt.h
@@ -0,0 +1,18 @@
+//===-- Implementation header for sqrt --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_SQRT_H
+#define LLVM_LIBC_SRC_MATH_SQRT_H
+
+namespace __llvm_libc {
+
+// Returns the square root of x, correctly rounded to nearest (ties to even).
+// Defined in sqrt.cpp, which forwards to fputil::sqrt (utils/FPUtil/Sqrt.h).
+double sqrt(double x);
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_MATH_SQRT_H

diff  --git a/libc/src/math/sqrtf.cpp b/libc/src/math/sqrtf.cpp
new file mode 100644
index 000000000000..391fa6a3281a
--- /dev/null
+++ b/libc/src/math/sqrtf.cpp
@@ -0,0 +1,16 @@
+//===-- Implementation of sqrtf function ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/common.h"
+#include "utils/FPUtil/Sqrt.h"
+
+namespace __llvm_libc {
+
+// C `sqrtf` entrypoint. All of the work happens in the header-only
+// fputil::sqrt template, instantiated here for float.
+float LLVM_LIBC_ENTRYPOINT(sqrtf)(float x) {
+  const float root = fputil::sqrt(x);
+  return root;
+}
+
+} // namespace __llvm_libc

diff  --git a/libc/src/math/sqrtf.h b/libc/src/math/sqrtf.h
new file mode 100644
index 000000000000..d1d06f3adfa8
--- /dev/null
+++ b/libc/src/math/sqrtf.h
@@ -0,0 +1,18 @@
+//===-- Implementation header for sqrtf -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_SQRTF_H
+#define LLVM_LIBC_SRC_MATH_SQRTF_H
+
+namespace __llvm_libc {
+
+// Returns the square root of x, correctly rounded to nearest (ties to even).
+// Defined in sqrtf.cpp, which forwards to fputil::sqrt (utils/FPUtil/Sqrt.h).
+float sqrtf(float x);
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_MATH_SQRTF_H

diff  --git a/libc/src/math/sqrtl.cpp b/libc/src/math/sqrtl.cpp
new file mode 100644
index 000000000000..16450349d23a
--- /dev/null
+++ b/libc/src/math/sqrtl.cpp
@@ -0,0 +1,18 @@
+//===-- Implementation of sqrtl function ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/common.h"
+#include "utils/FPUtil/Sqrt.h"
+
+namespace __llvm_libc {
+
+// C `sqrtl` entrypoint. The work is done by fputil::sqrt; on x86 targets that
+// call resolves to the long double specialization in SqrtLongDoubleX86.h.
+long double LLVM_LIBC_ENTRYPOINT(sqrtl)(long double x) {
+  const long double root = fputil::sqrt(x);
+  return root;
+}
+
+} // namespace __llvm_libc

diff  --git a/libc/src/math/sqrtl.h b/libc/src/math/sqrtl.h
new file mode 100644
index 000000000000..5fbfa1450714
--- /dev/null
+++ b/libc/src/math/sqrtl.h
@@ -0,0 +1,18 @@
+//===-- Implementation header for sqrtl -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_SQRTL_H
+#define LLVM_LIBC_SRC_MATH_SQRTL_H
+
+namespace __llvm_libc {
+
+// Returns the square root of x, correctly rounded to nearest (ties to even).
+// Defined in sqrtl.cpp, which forwards to fputil::sqrt (utils/FPUtil/Sqrt.h,
+// or the SqrtLongDoubleX86.h specialization on x86).
+long double sqrtl(long double x);
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_MATH_SQRTL_H

diff  --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index 2fe766a2ffc6..07b505207452 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -513,3 +513,42 @@ add_fp_unittest(
     libc.src.math.fmaxl
     libc.utils.FPUtil.fputil
 )
+
+# Unit tests for sqrtf, sqrt and sqrtl. They compare results against MPFR
+# (ASSERT_MPFR_MATCH), hence NEED_MPFR.
+add_fp_unittest(
+  sqrtf_test
+  NEED_MPFR
+  SUITE
+    libc_math_unittests
+  SRCS
+    sqrtf_test.cpp
+  DEPENDS
+    libc.include.math
+    libc.src.math.sqrtf
+    libc.utils.FPUtil.fputil
+)
+
+add_fp_unittest(
+  sqrt_test
+  NEED_MPFR
+  SUITE
+    libc_math_unittests
+  SRCS
+    sqrt_test.cpp
+  DEPENDS
+    libc.include.math
+    libc.src.math.sqrt
+    libc.utils.FPUtil.fputil
+)
+
+add_fp_unittest(
+  sqrtl_test
+  NEED_MPFR
+  SUITE
+    libc_math_unittests
+  SRCS
+    sqrtl_test.cpp
+  DEPENDS
+    libc.include.math
+    libc.src.math.sqrtl
+    libc.utils.FPUtil.fputil
+)

diff  --git a/libc/test/src/math/sqrt_test.cpp b/libc/test/src/math/sqrt_test.cpp
new file mode 100644
index 000000000000..7ff4978ec9e3
--- /dev/null
+++ b/libc/test/src/math/sqrt_test.cpp
@@ -0,0 +1,67 @@
+//===-- Unittests for sqrt -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "include/math.h"
+#include "src/math/sqrt.h"
+#include "utils/FPUtil/FPBits.h"
+#include "utils/FPUtil/TestHelpers.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+using FPBits = __llvm_libc::fputil::FPBits<double>;
+using UIntType = typename FPBits::UIntType;
+
+namespace mpfr = __llvm_libc::testing::mpfr;
+
+// Encoding of the hidden (implicit) leading bit. With the sign clear, every
+// bit pattern strictly below it is +0 or a positive denormal.
+constexpr UIntType HiddenBit =
+    UIntType(1) << __llvm_libc::fputil::MantissaWidth<double>::value;
+
+// Special-value fixtures. They are const, which gives them internal linkage:
+// as mutable globals they would be exported under their plain, unmangled
+// names, and `nan` could then clash with the C library's nan() function.
+const double nan = FPBits::buildNaN(1);
+const double inf = FPBits::inf();
+const double negInf = FPBits::negInf();
+
+TEST(SqrtTest, SpecialValues) {
+  // IEEE 754 special cases: NaN propagates, +Inf maps to itself, signed
+  // zeros are preserved and every other negative input yields NaN.
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrt(nan));
+  ASSERT_FP_EQ(inf, __llvm_libc::sqrt(inf));
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrt(negInf));
+  ASSERT_FP_EQ(0.0, __llvm_libc::sqrt(0.0));
+  ASSERT_FP_EQ(-0.0, __llvm_libc::sqrt(-0.0));
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrt(-1.0));
+  // A negative denormal must be rejected too: fputil::sqrt checks the sign
+  // before it normalizes denormal inputs.
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrt(-1e-310));
+
+  // Exact squares, including one below 1.
+  ASSERT_FP_EQ(1.0, __llvm_libc::sqrt(1.0));
+  ASSERT_FP_EQ(2.0, __llvm_libc::sqrt(4.0));
+  ASSERT_FP_EQ(3.0, __llvm_libc::sqrt(9.0));
+  ASSERT_FP_EQ(0.5, __llvm_libc::sqrt(0.25));
+}
+
+TEST(SqrtTest, DenormalValues) {
+  // Denormals with a single mantissa bit set, from the smallest upwards.
+  for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) {
+    FPBits denormal(0.0);
+    denormal.mantissa = mant;
+
+    ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, double(denormal),
+                      __llvm_libc::sqrt(denormal), 0.5);
+  }
+
+  // count + 1 evenly spaced encodings in [0, HiddenBit): +0 followed by
+  // positive denormals. The bits are copied into a double with memcpy;
+  // reading them through a reinterpret_cast'ed pointer breaks strict aliasing.
+  constexpr UIntType count = 1'000'001;
+  constexpr UIntType step = HiddenBit / count;
+  for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+    double x;
+    __builtin_memcpy(&x, &v, sizeof(x));
+    ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrt(x), 0.5);
+  }
+}
+
+TEST(SqrtTest, InDoubleRange) {
+  // count + 1 encodings spread evenly over the full UIntType range.
+  constexpr UIntType count = 10'000'001;
+  constexpr UIntType step = UIntType(-1) / count;
+  for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+    // Copy the bits with memcpy; reading through a reinterpret_cast'ed
+    // pointer would break strict aliasing.
+    double x;
+    __builtin_memcpy(&x, &v, sizeof(x));
+    // NaN and negative inputs produce NaN; SpecialValues spot-checks those.
+    if (isnan(x) || (x < 0)) {
+      continue;
+    }
+
+    ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrt(x), 0.5);
+  }
+}

diff  --git a/libc/test/src/math/sqrtf_test.cpp b/libc/test/src/math/sqrtf_test.cpp
new file mode 100644
index 000000000000..8c429065bb45
--- /dev/null
+++ b/libc/test/src/math/sqrtf_test.cpp
@@ -0,0 +1,67 @@
+//===-- Unittests for sqrtf -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "include/math.h"
+#include "src/math/sqrtf.h"
+#include "utils/FPUtil/FPBits.h"
+#include "utils/FPUtil/TestHelpers.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+using FPBits = __llvm_libc::fputil::FPBits<float>;
+using UIntType = typename FPBits::UIntType;
+
+namespace mpfr = __llvm_libc::testing::mpfr;
+
+// Encoding of the hidden (implicit) leading bit. With the sign clear, every
+// bit pattern strictly below it is +0 or a positive denormal.
+constexpr UIntType HiddenBit =
+    UIntType(1) << __llvm_libc::fputil::MantissaWidth<float>::value;
+
+// Special-value fixtures. They are const, which gives them internal linkage:
+// as mutable globals they would be exported under their plain, unmangled
+// names, and `nan` could then clash with the C library's nan() function.
+const float nan = FPBits::buildNaN(1);
+const float inf = FPBits::inf();
+const float negInf = FPBits::negInf();
+
+TEST(SqrtfTest, SpecialValues) {
+  // IEEE 754 special cases: NaN propagates, +Inf maps to itself, signed
+  // zeros are preserved and every other negative input yields NaN.
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(nan));
+  ASSERT_FP_EQ(inf, __llvm_libc::sqrtf(inf));
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(negInf));
+  ASSERT_FP_EQ(0.0f, __llvm_libc::sqrtf(0.0f));
+  ASSERT_FP_EQ(-0.0f, __llvm_libc::sqrtf(-0.0f));
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(-1.0f));
+  // A negative denormal must be rejected too: fputil::sqrt checks the sign
+  // before it normalizes denormal inputs.
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(-1e-40f));
+
+  // Exact squares, including one below 1.
+  ASSERT_FP_EQ(1.0f, __llvm_libc::sqrtf(1.0f));
+  ASSERT_FP_EQ(2.0f, __llvm_libc::sqrtf(4.0f));
+  ASSERT_FP_EQ(3.0f, __llvm_libc::sqrtf(9.0f));
+  ASSERT_FP_EQ(0.5f, __llvm_libc::sqrtf(0.25f));
+}
+
+TEST(SqrtfTest, DenormalValues) {
+  // Denormals with a single mantissa bit set, from the smallest upwards.
+  for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) {
+    FPBits denormal(0.0f);
+    denormal.mantissa = mant;
+
+    ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, float(denormal),
+                      __llvm_libc::sqrtf(denormal), 0.5);
+  }
+
+  // count + 1 evenly spaced encodings in [0, HiddenBit): +0 followed by
+  // positive denormals. The bits are copied into a float with memcpy;
+  // reading them through a reinterpret_cast'ed pointer breaks strict aliasing.
+  constexpr UIntType count = 1'000'001;
+  constexpr UIntType step = HiddenBit / count;
+  for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+    float x;
+    __builtin_memcpy(&x, &v, sizeof(x));
+    ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtf(x), 0.5);
+  }
+}
+
+TEST(SqrtfTest, InFloatRange) {
+  // count + 1 encodings spread evenly over the full UIntType range.
+  constexpr UIntType count = 10'000'001;
+  constexpr UIntType step = UIntType(-1) / count;
+  for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+    // Copy the bits with memcpy; reading through a reinterpret_cast'ed
+    // pointer would break strict aliasing.
+    float x;
+    __builtin_memcpy(&x, &v, sizeof(x));
+    // NaN and negative inputs produce NaN; SpecialValues spot-checks those.
+    if (isnan(x) || (x < 0)) {
+      continue;
+    }
+
+    ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtf(x), 0.5);
+  }
+}

diff  --git a/libc/test/src/math/sqrtl_test.cpp b/libc/test/src/math/sqrtl_test.cpp
new file mode 100644
index 000000000000..1fab3b2567e5
--- /dev/null
+++ b/libc/test/src/math/sqrtl_test.cpp
@@ -0,0 +1,67 @@
+//===-- Unittests for sqrtl ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "include/math.h"
+#include "src/math/sqrtl.h"
+#include "utils/FPUtil/FPBits.h"
+#include "utils/FPUtil/TestHelpers.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+using FPBits = __llvm_libc::fputil::FPBits<long double>;
+using UIntType = typename FPBits::UIntType;
+
+namespace mpfr = __llvm_libc::testing::mpfr;
+
+// Encoding of the leading (hidden, or on x86 explicit integer) bit. With the
+// sign clear, every bit pattern strictly below it is +0 or a positive denormal.
+constexpr UIntType HiddenBit =
+    UIntType(1) << __llvm_libc::fputil::MantissaWidth<long double>::value;
+
+// Special-value fixtures. They are const, which gives them internal linkage:
+// as mutable globals they would be exported under their plain, unmangled
+// names, and `nan` could then clash with the C library's nan() function.
+const long double nan = FPBits::buildNaN(1);
+const long double inf = FPBits::inf();
+const long double negInf = FPBits::negInf();
+
+TEST(SqrtlTest, SpecialValues) {
+  // IEEE 754 special cases: NaN propagates, +Inf maps to itself, signed
+  // zeros are preserved and every other negative input yields NaN.
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(nan));
+  ASSERT_FP_EQ(inf, __llvm_libc::sqrtl(inf));
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(negInf));
+  ASSERT_FP_EQ(0.0L, __llvm_libc::sqrtl(0.0L));
+  ASSERT_FP_EQ(-0.0L, __llvm_libc::sqrtl(-0.0L));
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(-1.0L));
+  // A negative denormal must be rejected too: fputil::sqrt checks the sign
+  // before it normalizes denormal inputs.
+  ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(-1e-4940L));
+
+  // Exact squares, including one below 1.
+  ASSERT_FP_EQ(1.0L, __llvm_libc::sqrtl(1.0L));
+  ASSERT_FP_EQ(2.0L, __llvm_libc::sqrtl(4.0L));
+  ASSERT_FP_EQ(3.0L, __llvm_libc::sqrtl(9.0L));
+  ASSERT_FP_EQ(0.5L, __llvm_libc::sqrtl(0.25L));
+}
+
+TEST(SqrtlTest, DenormalValues) {
+  // Denormals with a single mantissa bit set, from the smallest upwards.
+  for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) {
+    FPBits denormal(0.0L);
+    denormal.mantissa = mant;
+
+    ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, static_cast<long double>(denormal),
+                      __llvm_libc::sqrtl(denormal), 0.5);
+  }
+
+  // count + 1 evenly spaced encodings in [0, HiddenBit): +0 followed by
+  // positive denormals. The bits are copied into a long double with memcpy;
+  // reading them through a reinterpret_cast'ed pointer breaks strict aliasing.
+  constexpr UIntType count = 1'000'001;
+  constexpr UIntType step = HiddenBit / count;
+  for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+    long double x;
+    __builtin_memcpy(&x, &v, sizeof(x));
+    ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtl(x), 0.5);
+  }
+}
+
+TEST(SqrtlTest, InLongDoubleRange) {
+  // count + 1 encodings spread evenly over the full UIntType range.
+  constexpr UIntType count = 10'000'001;
+  constexpr UIntType step = UIntType(-1) / count;
+  for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+    // Copy the bits with memcpy; reading through a reinterpret_cast'ed
+    // pointer would break strict aliasing.
+    long double x;
+    __builtin_memcpy(&x, &v, sizeof(x));
+    // NaN and negative inputs produce NaN; SpecialValues spot-checks those.
+    if (isnan(x) || (x < 0)) {
+      continue;
+    }
+
+    ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtl(x), 0.5);
+  }
+}

diff  --git a/libc/utils/FPUtil/Sqrt.h b/libc/utils/FPUtil/Sqrt.h
new file mode 100644
index 000000000000..a12cc42fa340
--- /dev/null
+++ b/libc/utils/FPUtil/Sqrt.h
@@ -0,0 +1,186 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_H
+#define LLVM_LIBC_UTILS_FPUTIL_SQRT_H
+
+#include "FPBits.h"
+
+#include "utils/CPP/TypeTraits.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+namespace internal {
+
+// normalize<T>(exponent, mantissa): shifts a denormal mantissa left until its
+// leading 1 reaches bit MantissaWidth<T>::value (the hidden-bit position) and
+// subtracts the total shift from `exponent`, so 2^exponent * mantissa keeps
+// its value. Expects 0 < mantissa < 2^MantissaWidth<T>::value; fputil::sqrt
+// only calls it for denormal inputs, after zeros have been returned early.
+//
+// Each specialization is a table-driven binary search over fixed
+// (bound, shift) pairs: whenever mantissa < bound, the leading 1 is moved up
+// by `shift`. The pairs are chosen so that the leading 1 always ends exactly
+// at the hidden-bit position after ceil(log2(MantissaWidth)) comparisons.
+template <typename T>
+static inline void normalize(int &exponent,
+                             typename FPBits<T>::UIntType &mantissa);
+
+template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
+  // Use binary search to shift the leading 1 bit up to bit 23.
+  // With MantissaWidth<float> = 23, it will take
+  // ceil(log2(23)) = 5 steps checking the mantissa bits as follows (bits 23
+  // down to 0; a step shifts only when every bit shown as 0 is zero):
+  // Step 1: 0000 0000 0000 XXXX XXXX XXXX
+  // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
+  // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
+  // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
+  // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
+  constexpr int nsteps = 5; // = ceil(log2(MantissaWidth))
+  constexpr uint32_t bounds[nsteps] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
+                                       1 << 23};
+  constexpr int shifts[nsteps] = {12, 6, 3, 2, 1};
+
+  for (int i = 0; i < nsteps; ++i) {
+    if (mantissa < bounds[i]) {
+      exponent -= shifts[i];
+      mantissa <<= shifts[i];
+    }
+  }
+}
+
+template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
+  // Use binary search to shift the leading 1 bit similar to float.
+  // With MantissaWidth<double> = 52, it will take
+  // ceil(log2(52)) = 6 steps checking the mantissa bits.
+  // The first step tests bit 26 but shifts by 27; with the following shifts
+  // of 14, 7, 4, 2 and 1, every starting position still ends at bit 52.
+  constexpr int nsteps = 6; // = ceil(log2(MantissaWidth))
+  constexpr uint64_t bounds[nsteps] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
+                                       1ULL << 49, 1ULL << 51, 1ULL << 52};
+  constexpr int shifts[nsteps] = {27, 14, 7, 4, 2, 1};
+
+  for (int i = 0; i < nsteps; ++i) {
+    if (mantissa < bounds[i]) {
+      exponent -= shifts[i];
+      mantissa <<= shifts[i];
+    }
+  }
+}
+
+// The x87 80-bit long double gets its own specialization in
+// SqrtLongDoubleX86.h. NOTE(review): this one assumes long double is IEEE
+// binary128 (112-bit mantissa, __uint128_t storage) on every other target.
+#if !(defined(__x86_64__) || defined(__i386__))
+template <>
+inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
+  // Use binary search to shift the leading 1 bit similar to float.
+  // With MantissaWidth<long double> = 112, it will take
+  // ceil(log2(112)) = 7 steps checking the mantissa bits.
+  // Shifts of 57, 29, 15, 8, 4, 2 and 1 bring every starting position to
+  // bit 112.
+  constexpr int nsteps = 7; // = ceil(log2(MantissaWidth))
+  constexpr __uint128_t bounds[nsteps] = {
+      __uint128_t(1) << 56,  __uint128_t(1) << 84,  __uint128_t(1) << 98,
+      __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111,
+      __uint128_t(1) << 112};
+  constexpr int shifts[nsteps] = {57, 29, 15, 8, 4, 2, 1};
+
+  for (int i = 0; i < nsteps; ++i) {
+    if (mantissa < bounds[i]) {
+      exponent -= shifts[i];
+      mantissa <<= shifts[i];
+    }
+  }
+}
+#endif
+
+} // namespace internal
+
+// Correctly rounded IEEE 754 square root, rounding to nearest with ties to
+// even, computed with the bit-by-bit shift-and-add algorithm.
+//
+// Special cases:
+//   sqrt(NaN)  = NaN (the input is returned unchanged)
+//   sqrt(+Inf) = +Inf,  sqrt(-Inf) = NaN
+//   sqrt(+/-0) = +/-0
+//   sqrt(x)    = NaN for any other negative x
+// On x86 the 80-bit long double, whose integer bit is explicit, uses the
+// specialization in SqrtLongDoubleX86.h instead.
+template <typename T,
+          cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0>
+static inline T sqrt(T x) {
+  using UIntType = typename FPBits<T>::UIntType;
+  constexpr UIntType One = UIntType(1) << MantissaWidth<T>::value;
+  // The result is assembled as a raw encoding in a UIntType and then copied
+  // into a T, so the integer must be able to hold all of T's bytes.
+  static_assert(sizeof(T) <= sizeof(UIntType),
+                "UIntType must be able to hold the encoding of T");
+
+  FPBits<T> bits(x);
+
+  if (bits.isInfOrNaN()) {
+    // sqrt(-Inf) = NaN. NaN inputs and +Inf are returned unchanged.
+    if (bits.sign && (bits.mantissa == 0))
+      return FPBits<T>::buildNaN(One >> 1);
+    return x;
+  }
+  // sqrt(+0) = +0, sqrt(-0) = -0.
+  if (bits.isZero())
+    return x;
+  // sqrt(negative) = NaN.
+  if (bits.sign)
+    return FPBits<T>::buildNaN(One >> 1);
+
+  int xExp = bits.getExponent();
+  UIntType xMant = bits.mantissa;
+
+  // Step 1a: Normalize denormal inputs and append the hidden bit, so that
+  // x = 2^xExp * (xMant / One) with One <= xMant < 2 * One.
+  if (bits.exponent == 0) {
+    // Denormals share the exponent of the smallest normal numbers, one more
+    // than getExponent() reports for a zero exponent field (this assumes it
+    // returns exponent - exponentBias).
+    ++xExp;
+    internal::normalize<T>(xExp, xMant);
+  } else {
+    xMant |= One;
+  }
+
+  // Step 1b: Make the exponent even. Afterwards x = 2^xExp * (xMant / One)
+  // with xExp even and One <= xMant < 4 * One.
+  if (xExp & 1) {
+    --xExp;
+    xMant <<= 1;
+  }
+
+  // Now sqrt(x) = 2^(xExp / 2) * y with 1 <= y < 2, and y is produced one
+  // bit at a time. Writing y(n) = 1.y_1 y_2 ... y_n and m = xMant / One, the
+  // n-th residue
+  //   r(n) = 2^n * (m - y(n)^2)
+  // satisfies
+  //   r(n) = 2*r(n-1) - y_n * [2*y(n-1) + 2^(-n)],
+  // with y(0) = 1 and r(0) = m - 1. Hence the n-th bit is
+  //   y_n = 1 if 2*r(n-1) >= 2*y(n-1) + 2^(-n), and 0 otherwise.
+  // In the loop every quantity is scaled by One, so 2^(-n) is current_bit.
+  UIntType y = One;
+  UIntType r = xMant - One;
+
+  for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) {
+    r <<= 1;
+    UIntType tmp = (y << 1) + current_bit; // 2*y(n-1) + 2^(-n)
+    if (r >= tmp) {
+      r -= tmp;
+      y += current_bit;
+    }
+  }
+
+  // One more iteration gives the round bit. For n = MantissaWidth + 1, the
+  // scaled test 2*r >= 2*y + 1/2 is doubled to stay in integers:
+  // 4*r >= 4*y + 1. When the round bit is set, a non-zero remainder means
+  // the exact root lies strictly above the halfway point.
+  bool lsb = y & 1; // Least significant bit
+  bool rb = false;  // Round bit
+  r <<= 2;
+  UIntType tmp = (y << 2) + 1;
+  if (r >= tmp) {
+    r -= tmp;
+    rb = true;
+  }
+
+  // Drop the hidden bit and insert the biased exponent. xExp is even, so
+  // xExp / 2 is exact (and avoids right-shifting a possibly negative int).
+  xExp = xExp / 2 + FPBits<T>::exponentBias;
+  y = (y - One) | (static_cast<UIntType>(xExp) << MantissaWidth<T>::value);
+
+  // Round to nearest, ties to even.
+  if (rb && (lsb || (r != 0))) {
+    ++y;
+  }
+
+  // Copy the encoding out instead of dereferencing a reinterpret_cast'ed
+  // pointer, which would violate strict aliasing.
+  T result;
+  __builtin_memcpy(&result, &y, sizeof(result));
+  return result;
+}
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#if (defined(__x86_64__) || defined(__i386__))
+#include "SqrtLongDoubleX86.h"
+#endif // defined(__x86_64__) || defined(__i386__)
+
+#endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_H

diff  --git a/libc/utils/FPUtil/SqrtLongDoubleX86.h b/libc/utils/FPUtil/SqrtLongDoubleX86.h
new file mode 100644
index 000000000000..2ac73044cf92
--- /dev/null
+++ b/libc/utils/FPUtil/SqrtLongDoubleX86.h
@@ -0,0 +1,142 @@
+//===-- Square root of x86 long double numbers ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
+#define LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
+
+#include "FPBits.h"
+#include "utils/CPP/TypeTraits.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+#if (defined(__x86_64__) || defined(__i386__))
+namespace internal {
+
+// x86 specialization of the normalize helper declared in Sqrt.h: shifts a
+// denormal mantissa left until its leading 1 reaches bit
+// MantissaWidth<long double> (63, the explicit integer-bit position) and
+// subtracts the shift from `exponent`. Expects 0 < mantissa < 2^63.
+template <>
+inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
+  // Use binary search to shift the leading 1 bit similar to float.
+  // With MantissaWidth<long double> = 63, it will take
+  // ceil(log2(63)) = 6 steps checking the mantissa bits.
+  // The shifts halve each time (32 + 16 + 8 + 4 + 2 + 1 = 63), so any
+  // starting position in [0, 62] ends at bit 63.
+  constexpr int nsteps = 6; // = ceil(log2(MantissaWidth))
+  constexpr __uint128_t bounds[nsteps] = {
+      __uint128_t(1) << 32, __uint128_t(1) << 48, __uint128_t(1) << 56,
+      __uint128_t(1) << 60, __uint128_t(1) << 62, __uint128_t(1) << 63};
+  constexpr int shifts[nsteps] = {32, 16, 8, 4, 2, 1};
+
+  for (int i = 0; i < nsteps; ++i) {
+    if (mantissa < bounds[i]) {
+      exponent -= shifts[i];
+      mantissa <<= shifts[i];
+    }
+  }
+}
+
+} // namespace internal
+
+// Correctly rounded square root of an x86 extended-precision long double,
+// rounding to nearest with ties to even. This is the shift-and-add algorithm
+// of the generic fputil::sqrt, adapted to a format whose integer bit is
+// stored explicitly (FPBits<long double>::implicitBit) rather than hidden.
+//
+// Special cases:
+//   sqrt(NaN)  = NaN (the input is returned unchanged)
+//   sqrt(+Inf) = +Inf,  sqrt(-Inf) = NaN
+//   sqrt(+/-0) = +/-0
+//   sqrt(x)    = NaN for any other negative x, and for unnormal encodings
+//                (non-zero exponent with the integer bit clear)
+template <> inline long double sqrt<long double, 0>(long double x) {
+  using UIntType = typename FPBits<long double>::UIntType;
+  constexpr UIntType One = UIntType(1)
+                           << int(MantissaWidth<long double>::value);
+
+  FPBits<long double> bits(x);
+
+  if (bits.isInfOrNaN()) {
+    // sqrt(-Inf) = NaN. NaN inputs and +Inf are returned unchanged.
+    if (bits.sign && (bits.mantissa == 0))
+      return FPBits<long double>::buildNaN(One >> 1);
+    return x;
+  }
+  // sqrt(+0) = +0, sqrt(-0) = -0.
+  if (bits.isZero())
+    return x;
+  // sqrt(negative) = NaN.
+  if (bits.sign)
+    return FPBits<long double>::buildNaN(One >> 1);
+
+  int xExp = bits.getExponent();
+  UIntType xMant = bits.mantissa;
+
+  // Step 1a: Put the integer bit back into the mantissa, normalizing denormal
+  // inputs so that it ends up at bit MantissaWidth.
+  if (bits.implicitBit) {
+    xMant |= One;
+  } else if (bits.exponent == 0) {
+    // NOTE(review): unlike the generic sqrt there is no ++xExp here, so this
+    // relies on FPBits<long double>::getExponent() already reporting the
+    // denormal exponent 1 - exponentBias; confirm in the x86 FPBits.
+    internal::normalize<long double>(xExp, xMant);
+  } else {
+    // Unnormal encoding: non-zero exponent but integer bit clear. x87
+    // hardware rejects these as invalid operands. Without this check xMant
+    // would be below One and `xMant - One` below would wrap around.
+    return FPBits<long double>::buildNaN(One >> 1);
+  }
+
+  // Step 1b: Make the exponent even. Afterwards x = 2^xExp * (xMant / One)
+  // with xExp even and One <= xMant < 4 * One.
+  if (xExp & 1) {
+    --xExp;
+    xMant <<= 1;
+  }
+
+  // Now sqrt(x) = 2^(xExp / 2) * y with 1 <= y < 2, and y is produced one
+  // bit at a time. Writing y(n) = 1.y_1 y_2 ... y_n and m = xMant / One, the
+  // n-th residue
+  //   r(n) = 2^n * (m - y(n)^2)
+  // satisfies
+  //   r(n) = 2*r(n-1) - y_n * [2*y(n-1) + 2^(-n)],
+  // with y(0) = 1 and r(0) = m - 1. Hence the n-th bit is
+  //   y_n = 1 if 2*r(n-1) >= 2*y(n-1) + 2^(-n), and 0 otherwise.
+  // In the loop every quantity is scaled by One, so 2^(-n) is current_bit.
+  UIntType y = One;
+  UIntType r = xMant - One;
+
+  for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) {
+    r <<= 1;
+    UIntType tmp = (y << 1) + current_bit; // 2*y(n-1) + 2^(-n)
+    if (r >= tmp) {
+      r -= tmp;
+      y += current_bit;
+    }
+  }
+
+  // One more iteration gives the round bit. For n = MantissaWidth + 1, the
+  // scaled test 2*r >= 2*y + 1/2 is doubled to stay in integers:
+  // 4*r >= 4*y + 1. When the round bit is set, a non-zero remainder means
+  // the exact root lies strictly above the halfway point.
+  bool lsb = y & 1; // Least significant bit
+  bool rb = false;  // Round bit
+  r <<= 2;
+  UIntType tmp = (y << 2) + 1;
+  if (r >= tmp) {
+    r -= tmp;
+    rb = true;
+  }
+
+  // Round to nearest, ties to even. y stays below 2 * One: the largest
+  // possible m is 4 - 2^-62, whose root is below the rounding midpoint
+  // 2 - 2^-64, so the increment never carries into the integer bit.
+  if (rb && (lsb || (r != 0))) {
+    ++y;
+  }
+
+  // Assemble the result from its fields. xExp is even, so xExp / 2 is exact
+  // (and avoids right-shifting a possibly negative int).
+  FPBits<long double> out(0.0L);
+  out.exponent = xExp / 2 + FPBits<long double>::exponentBias;
+  out.implicitBit = 1;
+  out.mantissa = (y & (One - 1));
+
+  return out;
+}
+#endif // defined(__x86_64__) || defined(__i386__)
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H


        


More information about the libc-commits mailing list