[libc-commits] [libc] 7e7ecef - [libc] Replace type punning with bit_cast

Guillaume Chatelet via libc-commits libc-commits at lists.llvm.org
Tue Feb 8 12:46:11 PST 2022


Author: Guillaume Chatelet
Date: 2022-02-08T20:45:59Z
New Revision: 7e7ecef98080500ad12766695a86d687170d3639

URL: https://github.com/llvm/llvm-project/commit/7e7ecef98080500ad12766695a86d687170d3639
DIFF: https://github.com/llvm/llvm-project/commit/7e7ecef98080500ad12766695a86d687170d3639.diff

LOG: [libc] Replace type punning with bit_cast

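Although type punning through a union is well-defined in C, it is undefined behavior (UB) in C++. Reading an object through a reinterpret_cast pointer to an unrelated type also violates the strict aliasing rule.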
This patch introduces a bit_cast function to convert between types in a safe way.

This is necessary to get llvm-libc to compile with GCC.
This patch is extracted from D119002.

Differential Revision: https://reviews.llvm.org/D119145

Added: 
    libc/src/__support/CPP/Bit.h

Modified: 
    libc/src/__support/CPP/CMakeLists.txt
    libc/src/__support/FPUtil/FPBits.h
    libc/src/__support/FPUtil/Hypot.h
    libc/src/__support/FPUtil/ManipulationFunctions.h
    libc/src/__support/FPUtil/generic/sqrt.h
    libc/src/__support/FPUtil/x86_64/LongDoubleBits.h
    libc/src/__support/FPUtil/x86_64/NextAfterLongDouble.h
    libc/src/__support/FPUtil/x86_64/sqrt.h
    libc/src/math/generic/log10f.cpp
    libc/src/math/generic/log1pf.cpp
    libc/src/math/generic/log2f.cpp
    libc/src/math/generic/logf.cpp
    libc/src/math/generic/math_utils.h
    libc/src/string/memory_utils/CMakeLists.txt
    libc/src/string/memory_utils/elements_x86.h
    libc/test/src/math/NextAfterTest.h
    libc/test/src/math/SqrtTest.h
    utils/bazel/llvm-project-overlay/libc/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/libc/src/__support/CPP/Bit.h b/libc/src/__support/CPP/Bit.h
new file mode 100644
index 0000000000000..517762788757d
--- /dev/null
+++ b/libc/src/__support/CPP/Bit.h
@@ -0,0 +1,48 @@
+//===-- Freestanding version of bit_cast  -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SUPPORT_CPP_BIT_H
+#define LLVM_LIBC_SUPPORT_CPP_BIT_H
+
+namespace __llvm_libc {
+
+#if defined __has_builtin
+#if __has_builtin(__builtin_bit_cast)
+#define LLVM_LIBC_HAS_BUILTIN_BIT_CAST
+#endif // __has_builtin(__builtin_bit_cast)
+#endif // defined __has_builtin
+
+#if defined __has_builtin
+#if __has_builtin(__builtin_memcpy_inline)
+#define LLVM_LIBC_HAS_BUILTIN_MEMCPY_INLINE
+#endif // __has_builtin(__builtin_memcpy_inline)
+#endif // defined __has_builtin
+
+// Returns a To with the same bytes as `from` (like C++20 std::bit_cast).
+// With optimization on, expected to compile away on GCC >= 8 and Clang >= 6.
+template <class To, class From> constexpr To bit_cast(const From &from) {
+  static_assert(sizeof(To) == sizeof(From), "To and From must be of same size");
+#if defined(LLVM_LIBC_HAS_BUILTIN_BIT_CAST)
+  return __builtin_bit_cast(To, from);
+#else // Fallback: copy the bytes of `from` into a local To.
+  To to; // Requires To to be default-constructible.
+  char *dst = reinterpret_cast<char *>(&to); // char may alias any type.
+  const char *src = reinterpret_cast<const char *>(&from);
+#if defined(LLVM_LIBC_HAS_BUILTIN_MEMCPY_INLINE) // Never emits a memcpy call.
+  __builtin_memcpy_inline(dst, src, sizeof(To));
+#else // Portable byte-wise copy when the builtin is unavailable.
+  for (unsigned i = 0; i < sizeof(To); ++i)
+    dst[i] = src[i];
+#endif // defined(LLVM_LIBC_HAS_BUILTIN_MEMCPY_INLINE)
+  return to;
+#endif // defined(LLVM_LIBC_HAS_BUILTIN_BIT_CAST)
+}
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SUPPORT_CPP_BIT_H

diff  --git a/libc/src/__support/CPP/CMakeLists.txt b/libc/src/__support/CPP/CMakeLists.txt
index 72b3e786808a9..78f6cf2cd2c3f 100644
--- a/libc/src/__support/CPP/CMakeLists.txt
+++ b/libc/src/__support/CPP/CMakeLists.txt
@@ -3,9 +3,10 @@ add_header_library(
   HDRS
     Array.h
     ArrayRef.h
+    Bit.h
     Bitset.h
     Functional.h
+    Limits.h
     StringView.h
     TypeTraits.h
-    Limits.h
 )

diff  --git a/libc/src/__support/FPUtil/FPBits.h b/libc/src/__support/FPUtil/FPBits.h
index 19c1d1f1ac8f6..36df9c54c4777 100644
--- a/libc/src/__support/FPUtil/FPBits.h
+++ b/libc/src/__support/FPUtil/FPBits.h
@@ -11,6 +11,7 @@
 
 #include "PlatformDefs.h"
 
+#include "src/__support/CPP/Bit.h"
 #include "src/__support/CPP/TypeTraits.h"
 
 #include "FloatProperties.h"
@@ -35,7 +36,7 @@ template <typename T> struct ExponentWidth {
 // floating numbers. On x86 platforms however, the 'long double' type maps to
 // an x87 floating point format. This format is an IEEE 754 extension format.
 // It is handled as an explicit specialization of this class.
-template <typename T> union FPBits {
+template <typename T> struct FPBits {
   static_assert(cpp::IsFloatingPointType<T>::Value,
                 "FPBits instantiated with invalid type.");
 
@@ -76,7 +77,6 @@ template <typename T> union FPBits {
   bool get_sign() const {
     return ((bits & FloatProp::SIGN_MASK) >> (FloatProp::BIT_WIDTH - 1));
   }
-  T val;
 
   static_assert(sizeof(T) == sizeof(UIntType),
                 "Data type and integral representation have different sizes.");
@@ -96,15 +96,20 @@ template <typename T> union FPBits {
   // type match.
   template <typename XType,
             cpp::EnableIfType<cpp::IsSame<T, XType>::Value, int> = 0>
-  explicit FPBits(XType x) : val(x) {}
+  constexpr explicit FPBits(XType x)
+      : bits(__llvm_libc::bit_cast<UIntType>(x)) {}
 
   template <typename XType,
             cpp::EnableIfType<cpp::IsSame<XType, UIntType>::Value, int> = 0>
-  explicit FPBits(XType x) : bits(x) {}
+  constexpr explicit FPBits(XType x) : bits(x) {}
 
   FPBits() : bits(0) {}
 
-  explicit operator T() { return val; }
+  T get_val() const { return __llvm_libc::bit_cast<T>(bits); }
+
+  void set_val(T value) { bits = __llvm_libc::bit_cast<UIntType>(value); }
+
+  explicit operator T() const { return get_val(); }
 
   UIntType uintval() const { return bits; }
 

diff  --git a/libc/src/__support/FPUtil/Hypot.h b/libc/src/__support/FPUtil/Hypot.h
index bb658b0085fea..5111bbee39992 100644
--- a/libc/src/__support/FPUtil/Hypot.h
+++ b/libc/src/__support/FPUtil/Hypot.h
@@ -12,6 +12,7 @@
 #include "BasicOperations.h"
 #include "FEnvImpl.h"
 #include "FPBits.h"
+#include "src/__support/CPP/Bit.h"
 #include "src/__support/CPP/TypeTraits.h"
 
 namespace __llvm_libc {
@@ -285,7 +286,7 @@ static inline T hypot(T x, T y) {
   }
 
   y_new |= static_cast<UIntType>(out_exp) << MantissaWidth<T>::VALUE;
-  return *reinterpret_cast<T *>(&y_new);
+  return __llvm_libc::bit_cast<T>(y_new);
 }
 
 } // namespace fputil

diff  --git a/libc/src/__support/FPUtil/ManipulationFunctions.h b/libc/src/__support/FPUtil/ManipulationFunctions.h
index 0c6f322e3c517..f922c0ca77b68 100644
--- a/libc/src/__support/FPUtil/ManipulationFunctions.h
+++ b/libc/src/__support/FPUtil/ManipulationFunctions.h
@@ -14,6 +14,7 @@
 #include "NormalFloat.h"
 #include "PlatformDefs.h"
 
+#include "src/__support/CPP/Bit.h"
 #include "src/__support/CPP/TypeTraits.h"
 
 #include <limits.h>
@@ -171,7 +172,7 @@ static inline T nextafter(T from, T to) {
     int_val = (to_bits.uintval() & sign_mask) + UIntType(1);
   }
 
-  return *reinterpret_cast<T *>(&int_val);
+  return __llvm_libc::bit_cast<T>(int_val);
   // TODO: Raise floating point exceptions as required by the standard.
 }
 

diff  --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
index 1882d3e82ecf4..ad85bca94af55 100644
--- a/libc/src/__support/FPUtil/generic/sqrt.h
+++ b/libc/src/__support/FPUtil/generic/sqrt.h
@@ -10,6 +10,7 @@
 #define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
 
 #include "sqrt_80_bit_long_double.h"
+#include "src/__support/CPP/Bit.h"
 #include "src/__support/CPP/TypeTraits.h"
 #include "src/__support/FPUtil/FEnvImpl.h"
 #include "src/__support/FPUtil/FPBits.h"
@@ -203,7 +204,7 @@ sqrt(T x) {
         break;
       }
 
-      return *reinterpret_cast<T *>(&y);
+      return __llvm_libc::bit_cast<T>(y);
     }
   }
 }

diff  --git a/libc/src/__support/FPUtil/x86_64/LongDoubleBits.h b/libc/src/__support/FPUtil/x86_64/LongDoubleBits.h
index 4fc00b1429779..0a2dc452fc73d 100644
--- a/libc/src/__support/FPUtil/x86_64/LongDoubleBits.h
+++ b/libc/src/__support/FPUtil/x86_64/LongDoubleBits.h
@@ -9,6 +9,7 @@
 #ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_LONG_DOUBLE_BITS_H
 #define LLVM_LIBC_SRC_SUPPORT_FPUTIL_X86_64_LONG_DOUBLE_BITS_H
 
+#include "src/__support/CPP/Bit.h"
 #include "src/__support/architectures.h"
 
 #if !defined(LLVM_LIBC_ARCH_X86)
@@ -30,7 +31,7 @@ template <> struct Padding<4> { static constexpr unsigned VALUE = 16; };
 // x86_64 padding.
 template <> struct Padding<8> { static constexpr unsigned VALUE = 48; };
 
-template <> union FPBits<long double> {
+template <> struct FPBits<long double> {
   using UIntType = __uint128_t;
 
   static constexpr int EXPONENT_BIAS = 0x3FFF;
@@ -91,13 +92,11 @@ template <> union FPBits<long double> {
     return ((bits & FloatProp::SIGN_MASK) >> (FloatProp::BIT_WIDTH - 1));
   }
 
-  long double val;
-
   FPBits() : bits(0) {}
 
   template <typename XType,
             cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
-  explicit FPBits(XType x) : val(x) {
+  explicit FPBits(XType x) : bits(__llvm_libc::bit_cast<UIntType>(x)) {
     // bits starts uninitialized, and setting it to a long double only
     // overwrites the first 80 bits. This clears those upper bits.
     bits = bits & ((UIntType(1) << 80) - 1);
@@ -107,7 +106,7 @@ template <> union FPBits<long double> {
             cpp::EnableIfType<cpp::IsSame<XType, UIntType>::Value, int> = 0>
   explicit FPBits(XType x) : bits(x) {}
 
-  operator long double() { return val; }
+  operator long double() { return __llvm_libc::bit_cast<long double>(bits); }
 
   UIntType uintval() {
     // We zero the padding bits as they can contain garbage.

diff  --git a/libc/src/__support/FPUtil/x86_64/NextAfterLongDouble.h b/libc/src/__support/FPUtil/x86_64/NextAfterLongDouble.h
index db5c946a4d830..9251d838315a1 100644
--- a/libc/src/__support/FPUtil/x86_64/NextAfterLongDouble.h
+++ b/libc/src/__support/FPUtil/x86_64/NextAfterLongDouble.h
@@ -15,6 +15,7 @@
 #error "Invalid include"
 #endif
 
+#include "src/__support/CPP/Bit.h"
 #include "src/__support/FPUtil/FPBits.h"
 
 #include <stdint.h>
@@ -111,7 +112,7 @@ static inline long double nextafter(long double from, long double to) {
     }
   }
 
-  return *reinterpret_cast<long double *>(&int_val);
+  return __llvm_libc::bit_cast<long double>(int_val);
   // TODO: Raise floating point exceptions as required by the standard.
 }
 

diff  --git a/libc/src/__support/FPUtil/x86_64/sqrt.h b/libc/src/__support/FPUtil/x86_64/sqrt.h
index 8a8f8cf2238db..e3c68206c5a63 100644
--- a/libc/src/__support/FPUtil/x86_64/sqrt.h
+++ b/libc/src/__support/FPUtil/x86_64/sqrt.h
@@ -33,9 +33,8 @@ template <> inline double sqrt<double>(double x) {
 }
 
 template <> inline long double sqrt<long double>(long double x) {
-  long double result;
-  __asm__ __volatile__("fsqrt" : "=t"(result) : "t"(x));
-  return result;
+  __asm__ __volatile__("fsqrt" : "+t"(x));
+  return x;
 }
 
 } // namespace fputil

diff  --git a/libc/src/math/generic/log10f.cpp b/libc/src/math/generic/log10f.cpp
index c7dbbfd494c7d..b7509b3fb2a60 100644
--- a/libc/src/math/generic/log10f.cpp
+++ b/libc/src/math/generic/log10f.cpp
@@ -155,7 +155,7 @@ LLVM_LIBC_FUNCTION(float, log10f, (float x)) {
       return x;
     }
     // Normalize denormal inputs.
-    xbits.val *= 0x1.0p23f;
+    xbits.set_val(xbits.get_val() * 0x1.0p23f);
     m -= 23.0;
   }
 
@@ -164,7 +164,7 @@ LLVM_LIBC_FUNCTION(float, log10f, (float x)) {
   xbits.set_unbiased_exponent(0x7F);
   int f_index = xbits.get_mantissa() >> 16;
 
-  FPBits f(xbits.val);
+  FPBits f = xbits;
   f.bits &= ~0x0000'FFFF;
 
   double d = static_cast<float>(xbits) - static_cast<float>(f);

diff  --git a/libc/src/math/generic/log1pf.cpp b/libc/src/math/generic/log1pf.cpp
index b9494a4e02080..a4d615c9dd22f 100644
--- a/libc/src/math/generic/log1pf.cpp
+++ b/libc/src/math/generic/log1pf.cpp
@@ -59,7 +59,7 @@ INLINE_FMA static inline float log(double x) {
   int f_index =
       xbits.get_mantissa() >> 45; // fputil::MantissaWidth<double>::VALUE - 7
 
-  FPBits f(xbits.val);
+  FPBits f = xbits;
   // Clear the lowest 45 bits.
   f.bits &= ~0x0000'1FFF'FFFF'FFFFULL;
 

diff  --git a/libc/src/math/generic/log2f.cpp b/libc/src/math/generic/log2f.cpp
index dc5a6b670afd5..3957db54f72dd 100644
--- a/libc/src/math/generic/log2f.cpp
+++ b/libc/src/math/generic/log2f.cpp
@@ -138,7 +138,7 @@ LLVM_LIBC_FUNCTION(float, log2f, (float x)) {
       return x;
     }
     // Normalize denormal inputs.
-    xbits.val *= 0x1.0p23f;
+    xbits.set_val(xbits.get_val() * 0x1.0p23f);
     m = -23;
   }
 
@@ -149,7 +149,7 @@ LLVM_LIBC_FUNCTION(float, log2f, (float x)) {
   // lookup tables.
   int f_index = xbits.get_mantissa() >> 16;
 
-  FPBits f(xbits.val);
+  FPBits f = xbits;
   // Clear the lowest 16 bits.
   f.bits &= ~0x0000'FFFF;
 

diff  --git a/libc/src/math/generic/logf.cpp b/libc/src/math/generic/logf.cpp
index 390617bf1ea7f..3e712378b64c3 100644
--- a/libc/src/math/generic/logf.cpp
+++ b/libc/src/math/generic/logf.cpp
@@ -104,7 +104,7 @@ LLVM_LIBC_FUNCTION(float, logf, (float x)) {
       return x;
     }
     // Normalize denormal inputs.
-    xbits.val *= 0x1.0p23f;
+    xbits.set_val(xbits.get_val() * 0x1.0p23f);
     m = -23;
   }
 
@@ -113,7 +113,7 @@ LLVM_LIBC_FUNCTION(float, logf, (float x)) {
   xbits.set_unbiased_exponent(0x7F);
   int f_index = xbits.get_mantissa() >> 16;
 
-  FPBits f(xbits.val);
+  FPBits f = xbits;
   f.bits &= ~0x0000'FFFF;
 
   double d = static_cast<float>(xbits) - static_cast<float>(f);

diff  --git a/libc/src/math/generic/math_utils.h b/libc/src/math/generic/math_utils.h
index 7dbb78500616f..e705f8c0f534b 100644
--- a/libc/src/math/generic/math_utils.h
+++ b/libc/src/math/generic/math_utils.h
@@ -9,6 +9,7 @@
 #ifndef LLVM_LIBC_SRC_MATH_MATH_UTILS_H
 #define LLVM_LIBC_SRC_MATH_MATH_UTILS_H
 
+#include "src/__support/CPP/Bit.h"
 #include "src/__support/CPP/TypeTraits.h"
 #include "src/__support/common.h"
 #include <errno.h>
@@ -19,19 +20,19 @@
 namespace __llvm_libc {
 
 static inline uint32_t as_uint32_bits(float x) {
-  return *reinterpret_cast<uint32_t *>(&x);
+  return __llvm_libc::bit_cast<uint32_t>(x);
 }
 
 static inline uint64_t as_uint64_bits(double x) {
-  return *reinterpret_cast<uint64_t *>(&x);
+  return __llvm_libc::bit_cast<uint64_t>(x);
 }
 
 static inline float as_float(uint32_t x) {
-  return *reinterpret_cast<float *>(&x);
+  return __llvm_libc::bit_cast<float>(x);
 }
 
 static inline double as_double(uint64_t x) {
-  return *reinterpret_cast<double *>(&x);
+  return __llvm_libc::bit_cast<double>(x);
 }
 
 static inline uint32_t top12_bits(float x) { return as_uint32_bits(x) >> 20; }

diff  --git a/libc/src/string/memory_utils/CMakeLists.txt b/libc/src/string/memory_utils/CMakeLists.txt
index 51bbff3d27bfa..ca5cfdbd8db33 100644
--- a/libc/src/string/memory_utils/CMakeLists.txt
+++ b/libc/src/string/memory_utils/CMakeLists.txt
@@ -7,6 +7,8 @@ add_header_library(
     memcmp_implementations.h
     memcpy_implementations.h
     memset_implementations.h
+  DEPS
+    standalone_cpp
 )
 
 add_header_library(

diff  --git a/libc/src/string/memory_utils/elements_x86.h b/libc/src/string/memory_utils/elements_x86.h
index e09e62689e8fc..89a88c6703d5f 100644
--- a/libc/src/string/memory_utils/elements_x86.h
+++ b/libc/src/string/memory_utils/elements_x86.h
@@ -9,6 +9,7 @@
 #ifndef LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_X86_H
 #define LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_X86_H
 
+#include "src/__support/CPP/Bit.h"
 #include "src/__support/architectures.h"
 
 #if defined(LLVM_LIBC_ARCH_X86)
@@ -66,16 +67,18 @@ struct M128 {
   using T = char __attribute__((__vector_size__(SIZE)));
   static uint16_t mask(T value) {
     // NOLINTNEXTLINE(llvmlibc-callee-namespace)
-    return _mm_movemask_epi8(value);
+    return _mm_movemask_epi8(__llvm_libc::bit_cast<__m128i>(value));
   }
   static uint16_t not_equal_mask(T a, T b) { return mask(a != b); }
   static T load(const char *ptr) {
     // NOLINTNEXTLINE(llvmlibc-callee-namespace)
-    return _mm_loadu_si128(reinterpret_cast<__m128i_u const *>(ptr));
+    return __llvm_libc::bit_cast<T>(
+        _mm_loadu_si128(reinterpret_cast<__m128i_u const *>(ptr)));
   }
   static void store(char *ptr, T value) {
     // NOLINTNEXTLINE(llvmlibc-callee-namespace)
-    return _mm_storeu_si128(reinterpret_cast<__m128i_u *>(ptr), value);
+    return _mm_storeu_si128(reinterpret_cast<__m128i_u *>(ptr),
+                            __llvm_libc::bit_cast<__m128i>(value));
   }
   static T get_splatted_value(const char v) {
     const T splatted = {v, v, v, v, v, v, v, v, v, v, v, v, v, v, v, v};
@@ -91,16 +94,18 @@ struct M256 {
   using T = char __attribute__((__vector_size__(SIZE)));
   static uint32_t mask(T value) {
     // NOLINTNEXTLINE(llvmlibc-callee-namespace)
-    return _mm256_movemask_epi8(value);
+    return _mm256_movemask_epi8(__llvm_libc::bit_cast<__m256i>(value));
   }
   static uint32_t not_equal_mask(T a, T b) { return mask(a != b); }
   static T load(const char *ptr) {
     // NOLINTNEXTLINE(llvmlibc-callee-namespace)
-    return _mm256_loadu_si256(reinterpret_cast<__m256i const *>(ptr));
+    return __llvm_libc::bit_cast<T>(
+        _mm256_loadu_si256(reinterpret_cast<__m256i const *>(ptr)));
   }
   static void store(char *ptr, T value) {
     // NOLINTNEXTLINE(llvmlibc-callee-namespace)
-    return _mm256_storeu_si256(reinterpret_cast<__m256i *>(ptr), value);
+    return _mm256_storeu_si256(reinterpret_cast<__m256i *>(ptr),
+                               __llvm_libc::bit_cast<__m256i>(value));
   }
   static T get_splatted_value(const char v) {
     const T splatted = {v, v, v, v, v, v, v, v, v, v, v, v, v, v, v, v,
@@ -117,15 +122,16 @@ struct M512 {
   using T = char __attribute__((__vector_size__(SIZE)));
   static uint64_t not_equal_mask(T a, T b) {
     // NOLINTNEXTLINE(llvmlibc-callee-namespace)
-    return _mm512_cmpneq_epi8_mask(a, b);
+    return _mm512_cmpneq_epi8_mask(__llvm_libc::bit_cast<__m512i>(a),
+                                   __llvm_libc::bit_cast<__m512i>(b));
   }
   static T load(const char *ptr) {
     // NOLINTNEXTLINE(llvmlibc-callee-namespace)
-    return _mm512_loadu_epi8(ptr);
+    return __llvm_libc::bit_cast<T>(_mm512_loadu_epi8(ptr));
   }
   static void store(char *ptr, T value) {
     // NOLINTNEXTLINE(llvmlibc-callee-namespace)
-    return _mm512_storeu_epi8(ptr, value);
+    return _mm512_storeu_epi8(ptr, __llvm_libc::bit_cast<__m512i>(value));
   }
   static T get_splatted_value(const char v) {
     const T splatted = {v, v, v, v, v, v, v, v, v, v, v, v, v, v, v, v,

diff  --git a/libc/test/src/math/NextAfterTest.h b/libc/test/src/math/NextAfterTest.h
index 201d164337048..1554dcb905d4e 100644
--- a/libc/test/src/math/NextAfterTest.h
+++ b/libc/test/src/math/NextAfterTest.h
@@ -9,6 +9,7 @@
 #ifndef LLVM_LIBC_TEST_SRC_MATH_NEXTAFTERTEST_H
 #define LLVM_LIBC_TEST_SRC_MATH_NEXTAFTERTEST_H
 
+#include "src/__support/CPP/Bit.h"
 #include "src/__support/CPP/TypeTraits.h"
 #include "src/__support/FPUtil/BasicOperations.h"
 #include "src/__support/FPUtil/FPBits.h"
@@ -51,54 +52,54 @@ class NextAfterTestTemplate : public __llvm_libc::testing::Test {
     T x = zero;
     T result = func(x, T(1));
     UIntType expected_bits = 1;
-    T expected = *reinterpret_cast<T *>(&expected_bits);
+    T expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     result = func(x, T(-1));
     expected_bits = (UIntType(1) << (BIT_WIDTH_OF_TYPE - 1)) + 1;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     x = neg_zero;
     result = func(x, 1);
     expected_bits = 1;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     result = func(x, -1);
     expected_bits = (UIntType(1) << (BIT_WIDTH_OF_TYPE - 1)) + 1;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     // 'from' is max subnormal value.
-    x = *reinterpret_cast<const T *>(&max_subnormal);
+    x = __llvm_libc::bit_cast<T>(max_subnormal);
     result = func(x, 1);
-    expected = *reinterpret_cast<const T *>(&min_normal);
+    expected = __llvm_libc::bit_cast<T>(min_normal);
     ASSERT_FP_EQ(result, expected);
 
     result = func(x, 0);
     expected_bits = max_subnormal - 1;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     x = -x;
 
     result = func(x, -1);
     expected_bits = (UIntType(1) << (BIT_WIDTH_OF_TYPE - 1)) + min_normal;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     result = func(x, 0);
     expected_bits =
         (UIntType(1) << (BIT_WIDTH_OF_TYPE - 1)) + max_subnormal - 1;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     // 'from' is min subnormal value.
-    x = *reinterpret_cast<const T *>(&min_subnormal);
+    x = __llvm_libc::bit_cast<T>(min_subnormal);
     result = func(x, 1);
     expected_bits = min_subnormal + 1;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
     ASSERT_FP_EQ(func(x, 0), 0);
 
@@ -106,35 +107,35 @@ class NextAfterTestTemplate : public __llvm_libc::testing::Test {
     result = func(x, -1);
     expected_bits =
         (UIntType(1) << (BIT_WIDTH_OF_TYPE - 1)) + min_subnormal + 1;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
     ASSERT_FP_EQ(func(x, 0), T(-0.0));
 
     // 'from' is min normal.
-    x = *reinterpret_cast<const T *>(&min_normal);
+    x = __llvm_libc::bit_cast<T>(min_normal);
     result = func(x, 0);
     expected_bits = max_subnormal;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     result = func(x, inf);
     expected_bits = min_normal + 1;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     x = -x;
     result = func(x, 0);
     expected_bits = (UIntType(1) << (BIT_WIDTH_OF_TYPE - 1)) + max_subnormal;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     result = func(x, -inf);
     expected_bits = (UIntType(1) << (BIT_WIDTH_OF_TYPE - 1)) + min_normal + 1;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
 
     // 'from' is max normal and 'to' is infinity.
-    x = *reinterpret_cast<const T *>(&max_normal);
+    x = __llvm_libc::bit_cast<T>(max_normal);
     result = func(x, inf);
     ASSERT_FP_EQ(result, inf);
 
@@ -145,14 +146,14 @@ class NextAfterTestTemplate : public __llvm_libc::testing::Test {
     x = inf;
     result = func(x, 0);
     expected_bits = max_normal;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
     ASSERT_FP_EQ(func(x, inf), inf);
 
     x = neg_inf;
     result = func(x, 0);
     expected_bits = (UIntType(1) << (BIT_WIDTH_OF_TYPE - 1)) + max_normal;
-    expected = *reinterpret_cast<T *>(&expected_bits);
+    expected = __llvm_libc::bit_cast<T>(expected_bits);
     ASSERT_FP_EQ(result, expected);
     ASSERT_FP_EQ(func(x, neg_inf), neg_inf);
 

diff  --git a/libc/test/src/math/SqrtTest.h b/libc/test/src/math/SqrtTest.h
index 79306a60e2baa..0be4f3bd63e29 100644
--- a/libc/test/src/math/SqrtTest.h
+++ b/libc/test/src/math/SqrtTest.h
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/__support/CPP/Bit.h"
 #include "utils/MPFRWrapper/MPFRUtils.h"
 #include "utils/UnitTest/FPMatcher.h"
 #include "utils/UnitTest/Test.h"
@@ -47,7 +48,7 @@ template <typename T> class SqrtTest : public __llvm_libc::testing::Test {
     constexpr UIntType COUNT = 1'000'001;
     constexpr UIntType STEP = HIDDEN_BIT / COUNT;
     for (UIntType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
-      T x = *reinterpret_cast<T *>(&v);
+      T x = __llvm_libc::bit_cast<T>(v);
       test_all_rounding_modes(func, x);
     }
   }
@@ -56,7 +57,7 @@ template <typename T> class SqrtTest : public __llvm_libc::testing::Test {
     constexpr UIntType COUNT = 10'000'001;
     constexpr UIntType STEP = UIntType(-1) / COUNT;
     for (UIntType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
-      T x = *reinterpret_cast<T *>(&v);
+      T x = __llvm_libc::bit_cast<T>(v);
       if (isnan(x) || (x < 0)) {
         continue;
       }

diff  --git a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
index 3bccd9d7f894d..f4b5d3747ed35 100644
--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
@@ -39,6 +39,7 @@ cc_library(
     hdrs = [
         "src/__support/CPP/Array.h",
         "src/__support/CPP/ArrayRef.h",
+        "src/__support/CPP/Bit.h",
         "src/__support/CPP/Bitset.h",
         "src/__support/CPP/Functional.h",
         "src/__support/CPP/Limits.h",
@@ -661,6 +662,7 @@ cc_library(
     ],
     deps = [
         ":__support_common",
+        ":__support_standalone_cpp",
         ":libc_root",
     ],
 )


        


More information about the libc-commits mailing list