[libc-commits] [libc] 5078825 - [libc] Add implementations for sqrt, sqrtf, and sqrtl.
Tue Ly via libc-commits
libc-commits at lists.llvm.org
Wed Aug 26 06:46:49 PDT 2020
Author: Tue Ly
Date: 2020-08-26T09:46:18-04:00
New Revision: 5078825aa982905088502f14b5387fc5c96017fe
URL: https://github.com/llvm/llvm-project/commit/5078825aa982905088502f14b5387fc5c96017fe
DIFF: https://github.com/llvm/llvm-project/commit/5078825aa982905088502f14b5387fc5c96017fe.diff
LOG: [libc] Add implementations for sqrt, sqrtf, and sqrtl.
Differential Revision: https://reviews.llvm.org/D84726
Added:
libc/src/math/sqrt.cpp
libc/src/math/sqrt.h
libc/src/math/sqrtf.cpp
libc/src/math/sqrtf.h
libc/src/math/sqrtl.cpp
libc/src/math/sqrtl.h
libc/test/src/math/sqrt_test.cpp
libc/test/src/math/sqrtf_test.cpp
libc/test/src/math/sqrtl_test.cpp
libc/utils/FPUtil/Sqrt.h
libc/utils/FPUtil/SqrtLongDoubleX86.h
Modified:
libc/config/linux/aarch64/entrypoints.txt
libc/config/linux/api.td
libc/config/linux/x86_64/entrypoints.txt
libc/spec/stdc.td
libc/src/math/CMakeLists.txt
libc/test/src/math/CMakeLists.txt
Removed:
################################################################################
diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index fe63403ae221..34d07c24505d 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -75,6 +75,9 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.roundl
libc.src.math.sincosf
libc.src.math.sinf
+ libc.src.math.sqrt
+ libc.src.math.sqrtf
+ libc.src.math.sqrtl
libc.src.math.trunc
libc.src.math.truncf
libc.src.math.truncl
diff --git a/libc/config/linux/api.td b/libc/config/linux/api.td
index 6b50c4284ae2..063fe401da8b 100644
--- a/libc/config/linux/api.td
+++ b/libc/config/linux/api.td
@@ -204,6 +204,9 @@ def MathAPI : PublicAPI<"math.h"> {
"roundl",
"sincosf",
"sinf",
+ "sqrt",
+ "sqrtf",
+ "sqrtl",
"trunc",
"truncf",
"truncl",
diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index 35ca8bbeaf68..c24173b1d0e7 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -108,6 +108,9 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.roundl
libc.src.math.sincosf
libc.src.math.sinf
+ libc.src.math.sqrt
+ libc.src.math.sqrtf
+ libc.src.math.sqrtl
libc.src.math.trunc
libc.src.math.truncf
libc.src.math.truncl
diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td
index ac240ff9576e..15fc12d375e6 100644
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -314,6 +314,10 @@ def StdC : StandardSpec<"stdc"> {
FunctionSpec<"roundf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,
FunctionSpec<"roundl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>]>,
+ FunctionSpec<"sqrt", RetValSpec<DoubleType>, [ArgSpec<DoubleType>]>,
+ FunctionSpec<"sqrtf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,
+ FunctionSpec<"sqrtl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>]>,
+
FunctionSpec<"trunc", RetValSpec<DoubleType>, [ArgSpec<DoubleType>]>,
FunctionSpec<"truncf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,
FunctionSpec<"truncl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>]>,
diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index da18aeba9a2a..0c878de2ac95 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -485,3 +485,39 @@ add_entrypoint_object(
COMPILE_OPTIONS
-O2
)
+
+add_entrypoint_object(
+ sqrt
+ SRCS
+ sqrt.cpp
+ HDRS
+ sqrt.h
+ DEPENDS
+ libc.utils.FPUtil.fputil
+ COMPILE_OPTIONS
+ -O2
+)
+
+add_entrypoint_object(
+ sqrtf
+ SRCS
+ sqrtf.cpp
+ HDRS
+ sqrtf.h
+ DEPENDS
+ libc.utils.FPUtil.fputil
+ COMPILE_OPTIONS
+ -O2
+)
+
+add_entrypoint_object(
+ sqrtl
+ SRCS
+ sqrtl.cpp
+ HDRS
+ sqrtl.h
+ DEPENDS
+ libc.utils.FPUtil.fputil
+ COMPILE_OPTIONS
+ -O2
+)
diff --git a/libc/src/math/sqrt.cpp b/libc/src/math/sqrt.cpp
new file mode 100644
index 000000000000..32d38e61463d
--- /dev/null
+++ b/libc/src/math/sqrt.cpp
@@ -0,0 +1,16 @@
+//===-- Implementation of sqrt function -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "utils/FPUtil/Sqrt.h"
+#include "src/__support/common.h"
+
+namespace __llvm_libc {
+
+double LLVM_LIBC_ENTRYPOINT(sqrt)(double x) { return fputil::sqrt(x); }
+
+} // namespace __llvm_libc
diff --git a/libc/src/math/sqrt.h b/libc/src/math/sqrt.h
new file mode 100644
index 000000000000..2390e07b5dce
--- /dev/null
+++ b/libc/src/math/sqrt.h
@@ -0,0 +1,18 @@
+//===-- Implementation header for sqrt --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_SQRT_H
+#define LLVM_LIBC_SRC_MATH_SQRT_H
+
+namespace __llvm_libc {
+
+double sqrt(double x);
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_MATH_SQRT_H
diff --git a/libc/src/math/sqrtf.cpp b/libc/src/math/sqrtf.cpp
new file mode 100644
index 000000000000..391fa6a3281a
--- /dev/null
+++ b/libc/src/math/sqrtf.cpp
@@ -0,0 +1,16 @@
+//===-- Implementation of sqrtf function ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/common.h"
+#include "utils/FPUtil/Sqrt.h"
+
+namespace __llvm_libc {
+
+float LLVM_LIBC_ENTRYPOINT(sqrtf)(float x) { return fputil::sqrt(x); }
+
+} // namespace __llvm_libc
diff --git a/libc/src/math/sqrtf.h b/libc/src/math/sqrtf.h
new file mode 100644
index 000000000000..d1d06f3adfa8
--- /dev/null
+++ b/libc/src/math/sqrtf.h
@@ -0,0 +1,18 @@
+//===-- Implementation header for sqrtf -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_SQRTF_H
+#define LLVM_LIBC_SRC_MATH_SQRTF_H
+
+namespace __llvm_libc {
+
+float sqrtf(float x);
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_MATH_SQRTF_H
diff --git a/libc/src/math/sqrtl.cpp b/libc/src/math/sqrtl.cpp
new file mode 100644
index 000000000000..16450349d23a
--- /dev/null
+++ b/libc/src/math/sqrtl.cpp
@@ -0,0 +1,18 @@
+//===-- Implementation of sqrtl function ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/common.h"
+#include "utils/FPUtil/Sqrt.h"
+
+namespace __llvm_libc {
+
+long double LLVM_LIBC_ENTRYPOINT(sqrtl)(long double x) {
+ return fputil::sqrt(x);
+}
+
+} // namespace __llvm_libc
diff --git a/libc/src/math/sqrtl.h b/libc/src/math/sqrtl.h
new file mode 100644
index 000000000000..5fbfa1450714
--- /dev/null
+++ b/libc/src/math/sqrtl.h
@@ -0,0 +1,18 @@
+//===-- Implementation header for sqrtl -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_SQRTL_H
+#define LLVM_LIBC_SRC_MATH_SQRTL_H
+
+namespace __llvm_libc {
+
+long double sqrtl(long double x);
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_MATH_SQRTL_H
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index 2fe766a2ffc6..07b505207452 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -513,3 +513,42 @@ add_fp_unittest(
libc.src.math.fmaxl
libc.utils.FPUtil.fputil
)
+
+add_fp_unittest(
+ sqrtf_test
+ NEED_MPFR
+ SUITE
+ libc_math_unittests
+ SRCS
+ sqrtf_test.cpp
+ DEPENDS
+ libc.include.math
+ libc.src.math.sqrtf
+ libc.utils.FPUtil.fputil
+)
+
+add_fp_unittest(
+ sqrt_test
+ NEED_MPFR
+ SUITE
+ libc_math_unittests
+ SRCS
+ sqrt_test.cpp
+ DEPENDS
+ libc.include.math
+ libc.src.math.sqrt
+ libc.utils.FPUtil.fputil
+)
+
+add_fp_unittest(
+ sqrtl_test
+ NEED_MPFR
+ SUITE
+ libc_math_unittests
+ SRCS
+ sqrtl_test.cpp
+ DEPENDS
+ libc.include.math
+ libc.src.math.sqrtl
+ libc.utils.FPUtil.fputil
+)
diff --git a/libc/test/src/math/sqrt_test.cpp b/libc/test/src/math/sqrt_test.cpp
new file mode 100644
index 000000000000..7ff4978ec9e3
--- /dev/null
+++ b/libc/test/src/math/sqrt_test.cpp
@@ -0,0 +1,67 @@
+//===-- Unittests for sqrt -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "include/math.h"
+#include "src/math/sqrt.h"
+#include "utils/FPUtil/FPBits.h"
+#include "utils/FPUtil/TestHelpers.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+using FPBits = __llvm_libc::fputil::FPBits<double>;
+using UIntType = typename FPBits::UIntType;
+
+namespace mpfr = __llvm_libc::testing::mpfr;
+
+constexpr UIntType HiddenBit =
+ UIntType(1) << __llvm_libc::fputil::MantissaWidth<double>::value;
+
+double nan = FPBits::buildNaN(1);
+double inf = FPBits::inf();
+double negInf = FPBits::negInf();
+
+TEST(SqrtTest, SpecialValues) {
+ ASSERT_FP_EQ(nan, __llvm_libc::sqrt(nan));
+ ASSERT_FP_EQ(inf, __llvm_libc::sqrt(inf));
+ ASSERT_FP_EQ(nan, __llvm_libc::sqrt(negInf));
+ ASSERT_FP_EQ(0.0, __llvm_libc::sqrt(0.0));
+ ASSERT_FP_EQ(-0.0, __llvm_libc::sqrt(-0.0));
+ ASSERT_FP_EQ(nan, __llvm_libc::sqrt(-1.0));
+ ASSERT_FP_EQ(1.0, __llvm_libc::sqrt(1.0));
+ ASSERT_FP_EQ(2.0, __llvm_libc::sqrt(4.0));
+ ASSERT_FP_EQ(3.0, __llvm_libc::sqrt(9.0));
+}
+
+TEST(SqrtTest, DenormalValues) {
+ for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) {
+ FPBits denormal(0.0);
+ denormal.mantissa = mant;
+
+ ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, double(denormal),
+ __llvm_libc::sqrt(denormal), 0.5);
+ }
+
+ constexpr UIntType count = 1'000'001;
+ constexpr UIntType step = HiddenBit / count;
+ for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+ double x = *reinterpret_cast<double *>(&v);
+ ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrt(x), 0.5);
+ }
+}
+
+TEST(SqrtTest, InDoubleRange) {
+ constexpr UIntType count = 10'000'001;
+ constexpr UIntType step = UIntType(-1) / count;
+ for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+ double x = *reinterpret_cast<double *>(&v);
+ if (isnan(x) || (x < 0)) {
+ continue;
+ }
+
+ ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrt(x), 0.5);
+ }
+}
diff --git a/libc/test/src/math/sqrtf_test.cpp b/libc/test/src/math/sqrtf_test.cpp
new file mode 100644
index 000000000000..8c429065bb45
--- /dev/null
+++ b/libc/test/src/math/sqrtf_test.cpp
@@ -0,0 +1,67 @@
+//===-- Unittests for sqrtf -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "include/math.h"
+#include "src/math/sqrtf.h"
+#include "utils/FPUtil/FPBits.h"
+#include "utils/FPUtil/TestHelpers.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+using FPBits = __llvm_libc::fputil::FPBits<float>;
+using UIntType = typename FPBits::UIntType;
+
+namespace mpfr = __llvm_libc::testing::mpfr;
+
+constexpr UIntType HiddenBit =
+ UIntType(1) << __llvm_libc::fputil::MantissaWidth<float>::value;
+
+float nan = FPBits::buildNaN(1);
+float inf = FPBits::inf();
+float negInf = FPBits::negInf();
+
+TEST(SqrtfTest, SpecialValues) {
+ ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(nan));
+ ASSERT_FP_EQ(inf, __llvm_libc::sqrtf(inf));
+ ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(negInf));
+ ASSERT_FP_EQ(0.0f, __llvm_libc::sqrtf(0.0f));
+ ASSERT_FP_EQ(-0.0f, __llvm_libc::sqrtf(-0.0f));
+ ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(-1.0f));
+ ASSERT_FP_EQ(1.0f, __llvm_libc::sqrtf(1.0f));
+ ASSERT_FP_EQ(2.0f, __llvm_libc::sqrtf(4.0f));
+ ASSERT_FP_EQ(3.0f, __llvm_libc::sqrtf(9.0f));
+}
+
+TEST(SqrtfTest, DenormalValues) {
+ for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) {
+ FPBits denormal(0.0f);
+ denormal.mantissa = mant;
+
+ ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, float(denormal),
+ __llvm_libc::sqrtf(denormal), 0.5);
+ }
+
+ constexpr UIntType count = 1'000'001;
+ constexpr UIntType step = HiddenBit / count;
+ for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+ float x = *reinterpret_cast<float *>(&v);
+ ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtf(x), 0.5);
+ }
+}
+
+TEST(SqrtfTest, InFloatRange) {
+ constexpr UIntType count = 10'000'001;
+ constexpr UIntType step = UIntType(-1) / count;
+ for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+ float x = *reinterpret_cast<float *>(&v);
+ if (isnan(x) || (x < 0)) {
+ continue;
+ }
+
+ ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtf(x), 0.5);
+ }
+}
diff --git a/libc/test/src/math/sqrtl_test.cpp b/libc/test/src/math/sqrtl_test.cpp
new file mode 100644
index 000000000000..1fab3b2567e5
--- /dev/null
+++ b/libc/test/src/math/sqrtl_test.cpp
@@ -0,0 +1,67 @@
+//===-- Unittests for sqrtl ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "include/math.h"
+#include "src/math/sqrtl.h"
+#include "utils/FPUtil/FPBits.h"
+#include "utils/FPUtil/TestHelpers.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+using FPBits = __llvm_libc::fputil::FPBits<long double>;
+using UIntType = typename FPBits::UIntType;
+
+namespace mpfr = __llvm_libc::testing::mpfr;
+
+constexpr UIntType HiddenBit =
+ UIntType(1) << __llvm_libc::fputil::MantissaWidth<long double>::value;
+
+long double nan = FPBits::buildNaN(1);
+long double inf = FPBits::inf();
+long double negInf = FPBits::negInf();
+
+TEST(SqrtlTest, SpecialValues) {
+ ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(nan));
+ ASSERT_FP_EQ(inf, __llvm_libc::sqrtl(inf));
+ ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(negInf));
+ ASSERT_FP_EQ(0.0L, __llvm_libc::sqrtl(0.0L));
+ ASSERT_FP_EQ(-0.0L, __llvm_libc::sqrtl(-0.0L));
+ ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(-1.0L));
+ ASSERT_FP_EQ(1.0L, __llvm_libc::sqrtl(1.0L));
+ ASSERT_FP_EQ(2.0L, __llvm_libc::sqrtl(4.0L));
+ ASSERT_FP_EQ(3.0L, __llvm_libc::sqrtl(9.0L));
+}
+
+TEST(SqrtlTest, DenormalValues) {
+ for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) {
+ FPBits denormal(0.0L);
+ denormal.mantissa = mant;
+
+ ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, static_cast<long double>(denormal),
+ __llvm_libc::sqrtl(denormal), 0.5);
+ }
+
+ constexpr UIntType count = 1'000'001;
+ constexpr UIntType step = HiddenBit / count;
+ for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+ long double x = *reinterpret_cast<long double *>(&v);
+ ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtl(x), 0.5);
+ }
+}
+
+TEST(SqrtlTest, InLongDoubleRange) {
+ constexpr UIntType count = 10'000'001;
+ constexpr UIntType step = UIntType(-1) / count;
+ for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+ long double x = *reinterpret_cast<long double *>(&v);
+ if (isnan(x) || (x < 0)) {
+ continue;
+ }
+
+ ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtl(x), 0.5);
+ }
+}
diff --git a/libc/utils/FPUtil/Sqrt.h b/libc/utils/FPUtil/Sqrt.h
new file mode 100644
index 000000000000..a12cc42fa340
--- /dev/null
+++ b/libc/utils/FPUtil/Sqrt.h
@@ -0,0 +1,186 @@
+//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_H
+#define LLVM_LIBC_UTILS_FPUTIL_SQRT_H
+
+#include "FPBits.h"
+
+#include "utils/CPP/TypeTraits.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+namespace internal {
+
+template <typename T>
+static inline void normalize(int &exponent,
+ typename FPBits<T>::UIntType &mantissa);
+
+template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
+ // Binary-search the leading 1 bit of a non-zero mantissa up to bit 23.
+ // With MantissaWidth<float> = 23, it will take
+ // ceil(log2(23)) = 5 steps checking the mantissa bits as follows:
+ // Step 1: 0000 0000 0000 XXXX XXXX XXXX
+ // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
+ // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
+ // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
+ // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
+ constexpr int nsteps = 5; // = ceil(log2(MantissaWidth))
+ constexpr uint32_t bounds[nsteps] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
+ 1 << 23};
+ constexpr int shifts[nsteps] = {12, 6, 3, 2, 1};
+
+ for (int i = 0; i < nsteps; ++i) {
+ if (mantissa < bounds[i]) {
+ exponent -= shifts[i];
+ mantissa <<= shifts[i];
+ }
+ }
+}
+
+template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
+ // Binary-search the leading 1 bit of a non-zero mantissa up to bit 52,
+ // as in the float specialization. With MantissaWidth<double> = 52, it takes
+ // ceil(log2(52)) = 6 steps checking the mantissa bits.
+ constexpr int nsteps = 6; // = ceil(log2(MantissaWidth))
+ constexpr uint64_t bounds[nsteps] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
+ 1ULL << 49, 1ULL << 51, 1ULL << 52};
+ constexpr int shifts[nsteps] = {27, 14, 7, 4, 2, 1};
+
+ for (int i = 0; i < nsteps; ++i) {
+ if (mantissa < bounds[i]) {
+ exponent -= shifts[i];
+ mantissa <<= shifts[i];
+ }
+ }
+}
+
+#if !(defined(__x86_64__) || defined(__i386__))
+template <>
+inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
+ // Binary-search the leading 1 bit of a non-zero mantissa up to bit 112.
+ // This presumes MantissaWidth<long double> = 112 on non-x86 targets
+ // (TODO confirm); it takes ceil(log2(112)) = 7 steps checking the bits.
+ constexpr int nsteps = 7; // = ceil(log2(MantissaWidth))
+ constexpr __uint128_t bounds[nsteps] = {
+ __uint128_t(1) << 56, __uint128_t(1) << 84, __uint128_t(1) << 98,
+ __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111,
+ __uint128_t(1) << 112};
+ constexpr int shifts[nsteps] = {57, 29, 15, 8, 4, 2, 1};
+
+ for (int i = 0; i < nsteps; ++i) {
+ if (mantissa < bounds[i]) {
+ exponent -= shifts[i];
+ mantissa <<= shifts[i];
+ }
+ }
+}
+#endif
+
+} // namespace internal
+
+// Correctly rounded IEEE 754 SQRT with round to nearest, ties to even.
+// Shift-and-add algorithm.
+template <typename T,
+ cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0>
+static inline T sqrt(T x) {
+ using UIntType = typename FPBits<T>::UIntType;
+ constexpr UIntType One = UIntType(1) << MantissaWidth<T>::value;
+
+ FPBits<T> bits(x);
+
+ if (bits.isInfOrNaN()) {
+ if (bits.sign && (bits.mantissa == 0)) {
+ // sqrt(-Inf) = NaN
+ return FPBits<T>::buildNaN(One >> 1);
+ } else {
+ // sqrt(NaN) = NaN
+ // sqrt(+Inf) = +Inf
+ return x;
+ }
+ } else if (bits.isZero()) {
+ // sqrt(+0) = +0
+ // sqrt(-0) = -0
+ return x;
+ } else if (bits.sign) {
+ // sqrt( negative numbers ) = NaN
+ return FPBits<T>::buildNaN(One >> 1);
+ } else {
+ int xExp = bits.getExponent();
+ UIntType xMant = bits.mantissa;
+
+ // Step 1a: Normalize denormal input and append hidden bit to the mantissa
+ if (bits.exponent == 0) {
+ ++xExp; // let xExp be the correct exponent of One bit.
+ internal::normalize<T>(xExp, xMant);
+ } else {
+ xMant |= One;
+ }
+
+ // Step 1b: Make sure the exponent is even.
+ if (xExp & 1) {
+ --xExp;
+ xMant <<= 1;
+ }
+
+ // After step 1b, x = 2^(xExp) * xMant, where xExp is even, and
+ // 1 <= xMant < 4. So sqrt(x) = 2^(xExp / 2) * y, with 1 <= y < 2.
+ // Notice that the output of sqrt is always in the normal range.
+ // To perform the shift-and-add algorithm to find y, let us denote:
+ // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
+ // r(n) = 2^n ( xMant - y(n)^2 ).
+ // That leads to the following recurrence formula:
+ // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n) ]
+ // with the initial conditions: y(0) = 1, and r(0) = xMant - 1.
+ // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
+ // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n)
+ // 0 otherwise.
+ UIntType y = One;
+ UIntType r = xMant - One;
+
+ for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) {
+ r <<= 1;
+ UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n)
+ if (r >= tmp) {
+ r -= tmp;
+ y += current_bit;
+ }
+ }
+
+ // We compute one more iteration in order to round correctly.
+ bool lsb = y & 1; // Least significant bit
+ bool rb = false; // Round bit
+ r <<= 2;
+ UIntType tmp = (y << 2) + 1;
+ if (r >= tmp) {
+ r -= tmp;
+ rb = true;
+ }
+
+ // Remove hidden bit and append the exponent field.
+ xExp = ((xExp >> 1) + FPBits<T>::exponentBias);
+
+ y = (y - One) | (static_cast<UIntType>(xExp) << MantissaWidth<T>::value);
+ // Round to nearest, ties to even
+ if (rb && (lsb || (r != 0))) {
+ ++y;
+ }
+
+ return *reinterpret_cast<T *>(&y);
+ }
+}
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#if (defined(__x86_64__) || defined(__i386__))
+#include "SqrtLongDoubleX86.h"
+#endif // defined(__x86_64__) || defined(__i386__)
+
+#endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_H
diff --git a/libc/utils/FPUtil/SqrtLongDoubleX86.h b/libc/utils/FPUtil/SqrtLongDoubleX86.h
new file mode 100644
index 000000000000..2ac73044cf92
--- /dev/null
+++ b/libc/utils/FPUtil/SqrtLongDoubleX86.h
@@ -0,0 +1,142 @@
+//===-- Square root of x86 long double numbers ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
+#define LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
+
+#include "FPBits.h"
+#include "utils/CPP/TypeTraits.h"
+
+namespace __llvm_libc {
+namespace fputil {
+
+#if (defined(__x86_64__) || defined(__i386__))
+namespace internal {
+
+template <>
+inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
+ // Binary-search the leading 1 bit of a non-zero mantissa up to bit 63.
+ // With MantissaWidth<long double> = 63, it will take
+ // ceil(log2(63)) = 6 steps checking the mantissa bits.
+ constexpr int nsteps = 6; // = ceil(log2(MantissaWidth))
+ constexpr __uint128_t bounds[nsteps] = {
+ __uint128_t(1) << 32, __uint128_t(1) << 48, __uint128_t(1) << 56,
+ __uint128_t(1) << 60, __uint128_t(1) << 62, __uint128_t(1) << 63};
+ constexpr int shifts[nsteps] = {32, 16, 8, 4, 2, 1};
+
+ for (int i = 0; i < nsteps; ++i) {
+ if (mantissa < bounds[i]) {
+ exponent -= shifts[i];
+ mantissa <<= shifts[i];
+ }
+ }
+}
+
+} // namespace internal
+
+// Correctly rounded SQRT with round to nearest, ties to even.
+// Shift-and-add algorithm.
+template <> inline long double sqrt<long double, 0>(long double x) {
+ using UIntType = typename FPBits<long double>::UIntType;
+ constexpr UIntType One = UIntType(1)
+ << int(MantissaWidth<long double>::value);
+
+ FPBits<long double> bits(x);
+
+ if (bits.isInfOrNaN()) {
+ if (bits.sign && (bits.mantissa == 0)) {
+ // sqrt(-Inf) = NaN
+ return FPBits<long double>::buildNaN(One >> 1);
+ } else {
+ // sqrt(NaN) = NaN
+ // sqrt(+Inf) = +Inf
+ return x;
+ }
+ } else if (bits.isZero()) {
+ // sqrt(+0) = +0
+ // sqrt(-0) = -0
+ return x;
+ } else if (bits.sign) {
+ // sqrt( negative numbers ) = NaN
+ return FPBits<long double>::buildNaN(One >> 1);
+ } else {
+ int xExp = bits.getExponent();
+ UIntType xMant = bits.mantissa;
+
+ // Step 1a: Normalize denormal input
+ if (bits.implicitBit) {
+ xMant |= One;
+ } else if (bits.exponent == 0) {
+ internal::normalize<long double>(xExp, xMant);
+ }
+
+ // Step 1b: Make sure the exponent is even.
+ if (xExp & 1) {
+ --xExp;
+ xMant <<= 1;
+ }
+
+ // After step 1b, x = 2^(xExp) * xMant, where xExp is even, and
+ // 1 <= xMant < 4. So sqrt(x) = 2^(xExp / 2) * y, with 1 <= y < 2.
+ // Notice that the output of sqrt is always in the normal range.
+ // To perform the shift-and-add algorithm to find y, let us denote:
+ // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
+ // r(n) = 2^n ( xMant - y(n)^2 ).
+ // That leads to the following recurrence formula:
+ // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n) ]
+ // with the initial conditions: y(0) = 1, and r(0) = xMant - 1.
+ // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
+ // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n)
+ // 0 otherwise.
+ UIntType y = One;
+ UIntType r = xMant - One;
+
+ for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) {
+ r <<= 1;
+ UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n)
+ if (r >= tmp) {
+ r -= tmp;
+ y += current_bit;
+ }
+ }
+
+ // We compute one more iteration in order to round correctly.
+ bool lsb = y & 1; // Least significant bit
+ bool rb = false; // Round bit
+ r <<= 2;
+ UIntType tmp = (y << 2) + 1;
+ if (r >= tmp) {
+ r -= tmp;
+ rb = true;
+ }
+
+ // Append the exponent field.
+ xExp = ((xExp >> 1) + FPBits<long double>::exponentBias);
+ y |= (static_cast<UIntType>(xExp)
+ << (MantissaWidth<long double>::value + 1));
+
+ // Round to nearest, ties to even
+ if (rb && (lsb || (r != 0))) {
+ ++y;
+ }
+
+ // Extract output
+ FPBits<long double> out(0.0L);
+ out.exponent = xExp;
+ out.implicitBit = 1;
+ out.mantissa = (y & (One - 1));
+
+ return out;
+ }
+}
+#endif // defined(__x86_64__) || defined(__i386__)
+
+} // namespace fputil
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
More information about the libc-commits
mailing list