[llvm-branch-commits] [libcxx] release/19.x: [libc++][math] Fix undue overflowing of `std::hypot(x, y, z)` (#93350) (PR #100141)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jul 23 08:17:13 PDT 2024


https://github.com/llvmbot created https://github.com/llvm/llvm-project/pull/100141

Backport 9628777

Requested by: @ldionne

>From f281cb2886edb46067606b62163e0c3d6cdfd965 Mon Sep 17 00:00:00 2001
From: PaulXiCao <paulxicao7 at gmail.com>
Date: Tue, 23 Jul 2024 15:11:44 +0000
Subject: [PATCH] [libc++][math] Fix undue overflowing of `std::hypot(x,y,z)`
 (#93350)

The three-dimensional `std::hypot(x,y,z)` was implemented sub-optimally.
This led to possible over-/underflows in (intermediate) results, which
can be avoided with this change.

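The idea is to scale the arguments by a power of two whenever the
largest magnitude might make `x*x + y*y + z*z` over- or underflow, and
to undo the scaling after the square root (see the linked issue for the
full discussion).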

Tests have been added for problematic over- and underflows.

Closes #92782

(cherry picked from commit 9628777479a970db5d0c2d0b456dac6633864760)
---
 libcxx/include/__math/hypot.h                 | 89 ++++++++++++++++++
 libcxx/include/cmath                          | 25 +----
 .../test/libcxx/transitive_includes/cxx17.csv |  3 +
 .../test/libcxx/transitive_includes/cxx20.csv |  3 +
 .../test/libcxx/transitive_includes/cxx23.csv |  3 +
 .../test/libcxx/transitive_includes/cxx26.csv |  3 +
 .../test/std/numerics/c.math/cmath.pass.cpp   | 91 +++++++++++++++----
 libcxx/test/support/fp_compare.h              | 45 ++++-----
 8 files changed, 197 insertions(+), 65 deletions(-)

diff --git a/libcxx/include/__math/hypot.h b/libcxx/include/__math/hypot.h
index 1bf193a9ab7ee..61fd260c59409 100644
--- a/libcxx/include/__math/hypot.h
+++ b/libcxx/include/__math/hypot.h
@@ -15,10 +15,21 @@
 #include <__type_traits/is_same.h>
 #include <__type_traits/promote.h>
 
+#if _LIBCPP_STD_VER >= 17
+#  include <__algorithm/max.h>
+#  include <__math/abs.h>
+#  include <__math/roots.h>
+#  include <__utility/pair.h>
+#  include <limits>
+#endif
+
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 #  pragma GCC system_header
 #endif
 
+_LIBCPP_PUSH_MACROS
+#include <__undef_macros>
+
 _LIBCPP_BEGIN_NAMESPACE_STD
 
 namespace __math {
@@ -41,8 +52,86 @@ inline _LIBCPP_HIDE_FROM_ABI typename __promote<_A1, _A2>::type hypot(_A1 __x, _
   return __math::hypot((__result_type)__x, (__result_type)__y);
 }
 
+#if _LIBCPP_STD_VER >= 17
+// Factors needed to determine if over-/underflow might happen for `std::hypot(x,y,z)`.
+// returns [overflow_threshold, overflow_scale]
+template <class _Real>
+_LIBCPP_HIDE_FROM_ABI std::pair<_Real, _Real> __hypot_factors() {
+  static_assert(std::numeric_limits<_Real>::is_iec559);
+
+  if constexpr (std::is_same_v<_Real, float>) {
+    static_assert(-125 == std::numeric_limits<_Real>::min_exponent);
+    static_assert(+128 == std::numeric_limits<_Real>::max_exponent);
+    return {0x1.0p+62f, 0x1.0p-70f};
+  } else if constexpr (std::is_same_v<_Real, double>) {
+    static_assert(-1021 == std::numeric_limits<_Real>::min_exponent);
+    static_assert(+1024 == std::numeric_limits<_Real>::max_exponent);
+    return {0x1.0p+510, 0x1.0p-600};
+  } else { // long double
+    static_assert(std::is_same_v<_Real, long double>);
+
+    // preprocessor guard necessary, otherwise literals (e.g. `0x1.0p+8'190l`) throw warnings even when shielded by `if
+    // constexpr`
+#  if __DBL_MAX_EXP__ == __LDBL_MAX_EXP__
+    static_assert(sizeof(_Real) == sizeof(double));
+    return static_cast<std::pair<_Real, _Real>>(__math::__hypot_factors<double>());
+#  else
+    static_assert(sizeof(_Real) > sizeof(double));
+    static_assert(-16381 == std::numeric_limits<_Real>::min_exponent);
+    static_assert(+16384 == std::numeric_limits<_Real>::max_exponent);
+    return {0x1.0p+8190l, 0x1.0p-9000l};
+#  endif
+  }
+}
+
+// Computes the three-dimensional hypotenuse: `std::hypot(x,y,z)`.
+// The naive implementation might over-/underflow which is why this implementation is more involved:
+//    If the square of an argument might run into issues, we scale the arguments appropriately.
+// See https://github.com/llvm/llvm-project/issues/92782 for a detailed discussion and summary.
+template <class _Real>
+_LIBCPP_HIDE_FROM_ABI _Real __hypot(_Real __x, _Real __y, _Real __z) {
+  const _Real __max_abs = std::max(__math::fabs(__x), std::max(__math::fabs(__y), __math::fabs(__z)));
+  const auto [__overflow_threshold, __overflow_scale] = __math::__hypot_factors<_Real>();
+  _Real __scale;
+  if (__max_abs > __overflow_threshold) { // x*x + y*y + z*z might overflow
+    __scale = __overflow_scale;
+    __x *= __scale;
+    __y *= __scale;
+    __z *= __scale;
+  } else if (__max_abs < 1 / __overflow_threshold) { // x*x + y*y + z*z might underflow
+    __scale = 1 / __overflow_scale;
+    __x *= __scale;
+    __y *= __scale;
+    __z *= __scale;
+  } else
+    __scale = 1;
+  return __math::sqrt(__x * __x + __y * __y + __z * __z) / __scale;
+}
+
+inline _LIBCPP_HIDE_FROM_ABI float hypot(float __x, float __y, float __z) { return __math::__hypot(__x, __y, __z); }
+
+inline _LIBCPP_HIDE_FROM_ABI double hypot(double __x, double __y, double __z) { return __math::__hypot(__x, __y, __z); }
+
+inline _LIBCPP_HIDE_FROM_ABI long double hypot(long double __x, long double __y, long double __z) {
+  return __math::__hypot(__x, __y, __z);
+}
+
+template <class _A1,
+          class _A2,
+          class _A3,
+          std::enable_if_t< is_arithmetic_v<_A1> && is_arithmetic_v<_A2> && is_arithmetic_v<_A3>, int> = 0 >
+_LIBCPP_HIDE_FROM_ABI typename __promote<_A1, _A2, _A3>::type hypot(_A1 __x, _A2 __y, _A3 __z) _NOEXCEPT {
+  using __result_type = typename __promote<_A1, _A2, _A3>::type;
+  static_assert(!(
+      std::is_same_v<_A1, __result_type> && std::is_same_v<_A2, __result_type> && std::is_same_v<_A3, __result_type>));
+  return __math::__hypot(
+      static_cast<__result_type>(__x), static_cast<__result_type>(__y), static_cast<__result_type>(__z));
+}
+#endif
+
 } // namespace __math
 
 _LIBCPP_END_NAMESPACE_STD
+_LIBCPP_POP_MACROS
 
 #endif // _LIBCPP___MATH_HYPOT_H
diff --git a/libcxx/include/cmath b/libcxx/include/cmath
index 3c22604a683c3..6480c4678ce33 100644
--- a/libcxx/include/cmath
+++ b/libcxx/include/cmath
@@ -313,6 +313,7 @@ constexpr long double lerp(long double a, long double b, long double t) noexcept
 */
 
 #include <__config>
+#include <__math/hypot.h>
 #include <__type_traits/enable_if.h>
 #include <__type_traits/is_arithmetic.h>
 #include <__type_traits/is_constant_evaluated.h>
@@ -553,30 +554,6 @@ using ::scalbnl _LIBCPP_USING_IF_EXISTS;
 using ::tgammal _LIBCPP_USING_IF_EXISTS;
 using ::truncl _LIBCPP_USING_IF_EXISTS;
 
-#if _LIBCPP_STD_VER >= 17
-inline _LIBCPP_HIDE_FROM_ABI float hypot(float __x, float __y, float __z) {
-  return sqrt(__x * __x + __y * __y + __z * __z);
-}
-inline _LIBCPP_HIDE_FROM_ABI double hypot(double __x, double __y, double __z) {
-  return sqrt(__x * __x + __y * __y + __z * __z);
-}
-inline _LIBCPP_HIDE_FROM_ABI long double hypot(long double __x, long double __y, long double __z) {
-  return sqrt(__x * __x + __y * __y + __z * __z);
-}
-
-template <class _A1, class _A2, class _A3>
-inline _LIBCPP_HIDE_FROM_ABI
-typename enable_if_t< is_arithmetic<_A1>::value && is_arithmetic<_A2>::value && is_arithmetic<_A3>::value,
-                      __promote<_A1, _A2, _A3> >::type
-hypot(_A1 __lcpp_x, _A2 __lcpp_y, _A3 __lcpp_z) _NOEXCEPT {
-  typedef typename __promote<_A1, _A2, _A3>::type __result_type;
-  static_assert(
-      !(is_same<_A1, __result_type>::value && is_same<_A2, __result_type>::value && is_same<_A3, __result_type>::value),
-      "");
-  return std::hypot((__result_type)__lcpp_x, (__result_type)__lcpp_y, (__result_type)__lcpp_z);
-}
-#endif
-
 template <class _A1, __enable_if_t<is_floating_point<_A1>::value, int> = 0>
 _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR bool __constexpr_isnan(_A1 __lcpp_x) _NOEXCEPT {
 #if __has_builtin(__builtin_isnan)
diff --git a/libcxx/test/libcxx/transitive_includes/cxx17.csv b/libcxx/test/libcxx/transitive_includes/cxx17.csv
index 2c028462144ee..8099d2b79c4be 100644
--- a/libcxx/test/libcxx/transitive_includes/cxx17.csv
+++ b/libcxx/test/libcxx/transitive_includes/cxx17.csv
@@ -130,6 +130,9 @@ chrono type_traits
 chrono vector
 chrono version
 cinttypes cstdint
+cmath cstddef
+cmath cstdint
+cmath initializer_list
 cmath limits
 cmath type_traits
 cmath version
diff --git a/libcxx/test/libcxx/transitive_includes/cxx20.csv b/libcxx/test/libcxx/transitive_includes/cxx20.csv
index 982c2013e3417..384e51b101f31 100644
--- a/libcxx/test/libcxx/transitive_includes/cxx20.csv
+++ b/libcxx/test/libcxx/transitive_includes/cxx20.csv
@@ -135,6 +135,9 @@ chrono type_traits
 chrono vector
 chrono version
 cinttypes cstdint
+cmath cstddef
+cmath cstdint
+cmath initializer_list
 cmath limits
 cmath type_traits
 cmath version
diff --git a/libcxx/test/libcxx/transitive_includes/cxx23.csv b/libcxx/test/libcxx/transitive_includes/cxx23.csv
index 8ffb71d8b566b..46b833d143f39 100644
--- a/libcxx/test/libcxx/transitive_includes/cxx23.csv
+++ b/libcxx/test/libcxx/transitive_includes/cxx23.csv
@@ -83,6 +83,9 @@ chrono string_view
 chrono vector
 chrono version
 cinttypes cstdint
+cmath cstddef
+cmath cstdint
+cmath initializer_list
 cmath limits
 cmath version
 codecvt cctype
diff --git a/libcxx/test/libcxx/transitive_includes/cxx26.csv b/libcxx/test/libcxx/transitive_includes/cxx26.csv
index 8ffb71d8b566b..46b833d143f39 100644
--- a/libcxx/test/libcxx/transitive_includes/cxx26.csv
+++ b/libcxx/test/libcxx/transitive_includes/cxx26.csv
@@ -83,6 +83,9 @@ chrono string_view
 chrono vector
 chrono version
 cinttypes cstdint
+cmath cstddef
+cmath cstdint
+cmath initializer_list
 cmath limits
 cmath version
 codecvt cctype
diff --git a/libcxx/test/std/numerics/c.math/cmath.pass.cpp b/libcxx/test/std/numerics/c.math/cmath.pass.cpp
index 9379084499792..19b5fd0cf8996 100644
--- a/libcxx/test/std/numerics/c.math/cmath.pass.cpp
+++ b/libcxx/test/std/numerics/c.math/cmath.pass.cpp
@@ -12,14 +12,17 @@
 
 // <cmath>
 
+#include <array>
 #include <cmath>
 #include <limits>
 #include <type_traits>
 #include <cassert>
 
+#include "fp_compare.h"
 #include "test_macros.h"
 #include "hexfloat.h"
 #include "truncate_fp.h"
+#include "type_algorithms.h"
 
 // convertible to int/float/double/etc
 template <class T, int N=0>
@@ -1113,6 +1116,56 @@ void test_fmin()
     assert(std::fmin(1,0) == 0);
 }
 
+#if TEST_STD_VER >= 17
+struct TestHypot3 {
+  template <class Real>
+  void operator()() const {
+    const auto check = [](Real elem, Real abs_tol) {
+      assert(std::isfinite(std::hypot(elem, Real(0), Real(0))));
+      assert(fptest_close(std::hypot(elem, Real(0), Real(0)), elem, abs_tol));
+      assert(std::isfinite(std::hypot(elem, elem, Real(0))));
+      assert(fptest_close(std::hypot(elem, elem, Real(0)), std::sqrt(Real(2)) * elem, abs_tol));
+      assert(std::isfinite(std::hypot(elem, elem, elem)));
+      assert(fptest_close(std::hypot(elem, elem, elem), std::sqrt(Real(3)) * elem, abs_tol));
+    };
+
+    { // check for overflow
+      const auto [elem, abs_tol] = []() -> std::array<Real, 2> {
+        if constexpr (std::is_same_v<Real, float>)
+          return {1e20f, 1e16f};
+        else if constexpr (std::is_same_v<Real, double>)
+          return {1e300, 1e287};
+        else { // long double
+#  if __DBL_MAX_EXP__ == __LDBL_MAX_EXP__
+          return {1e300l, 1e287l}; // 64-bit
+#  else
+          return {1e4000l, 1e3985l}; // 80- or 128-bit
+#  endif
+        }
+      }();
+      check(elem, abs_tol);
+    }
+
+    { // check for underflow
+      const auto [elem, abs_tol] = []() -> std::array<Real, 2> {
+        if constexpr (std::is_same_v<Real, float>)
+          return {1e-20f, 1e-24f};
+        else if constexpr (std::is_same_v<Real, double>)
+          return {1e-287, 1e-300};
+        else { // long double
+#  if __DBL_MAX_EXP__ == __LDBL_MAX_EXP__
+          return {1e-287l, 1e-300l}; // 64-bit
+#  else
+          return {1e-3985l, 1e-4000l}; // 80- or 128-bit
+#  endif
+        }
+      }();
+      check(elem, abs_tol);
+    }
+  }
+};
+#endif
+
 void test_hypot()
 {
     static_assert((std::is_same<decltype(std::hypot((float)0, (float)0)), float>::value), "");
@@ -1135,25 +1188,31 @@ void test_hypot()
     static_assert((std::is_same<decltype(hypot(Ambiguous(), Ambiguous())), Ambiguous>::value), "");
     assert(std::hypot(3,4) == 5);
 
-#if TEST_STD_VER > 14
-    static_assert((std::is_same<decltype(std::hypot((float)0, (float)0, (float)0)), float>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (bool)0, (float)0)), double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (unsigned short)0, (double)0)), double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (int)0, (long double)0)), long double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (double)0, (long)0)), double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (long double)0, (unsigned long)0)), long double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (int)0, (long long)0)), double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (int)0, (unsigned long long)0)), double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (double)0, (double)0)), double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (long double)0, (long double)0)), long double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (float)0, (double)0)), double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (float)0, (long double)0)), long double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((float)0, (double)0, (long double)0)), long double>::value), "");
-    static_assert((std::is_same<decltype(std::hypot((int)0, (int)0, (int)0)), double>::value), "");
-    static_assert((std::is_same<decltype(hypot(Ambiguous(), Ambiguous(), Ambiguous())), Ambiguous>::value), "");
+#if TEST_STD_VER >= 17
+    // clang-format off
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (float)0,          (float)0)),              float>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (bool)0,           (float)0)),              double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (unsigned short)0, (double)0)),             double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (int)0,            (long double)0)),        long double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (double)0,         (long)0)),               double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (long double)0,    (unsigned long)0)),      long double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (int)0,            (long long)0)),          double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (int)0,            (unsigned long long)0)), double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (double)0,         (double)0)),             double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (long double)0,    (long double)0)),        long double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (float)0,          (double)0)),             double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (float)0,          (long double)0)),        long double>));
+    static_assert((std::is_same_v<decltype(std::hypot((float)0, (double)0,         (long double)0)),        long double>));
+    static_assert((std::is_same_v<decltype(std::hypot((int)0,   (int)0,            (int)0)),                double>));
+    static_assert((std::is_same_v<decltype(hypot(Ambiguous(), Ambiguous(), Ambiguous())), Ambiguous>));
+    // clang-format on
 
     assert(std::hypot(2,3,6) == 7);
     assert(std::hypot(1,4,8) == 9);
+
+    // Check for undue over-/underflows of intermediate results.
+    // See discussion at https://github.com/llvm/llvm-project/issues/92782.
+    types::for_each(types::floating_point_types(), TestHypot3());
 #endif
 }
 
diff --git a/libcxx/test/support/fp_compare.h b/libcxx/test/support/fp_compare.h
index 1d1933b0bcd81..3088a211dadc3 100644
--- a/libcxx/test/support/fp_compare.h
+++ b/libcxx/test/support/fp_compare.h
@@ -9,39 +9,34 @@
 #ifndef SUPPORT_FP_COMPARE_H
 #define SUPPORT_FP_COMPARE_H
 
-#include <cmath>      // for std::abs
-#include <algorithm>  // for std::max
+#include <cmath>     // for std::abs
+#include <algorithm> // for std::max
 #include <cassert>
+#include <__config>
 
 // See https://www.boost.org/doc/libs/1_70_0/libs/test/doc/html/boost_test/testing_tools/extended_comparison/floating_point/floating_points_comparison_theory.html
 
-template<typename T>
-bool fptest_close(T val, T expected, T eps)
-{
-    constexpr T zero = T(0);
-    assert(eps >= zero);
+template <typename T>
+bool fptest_close(T val, T expected, T eps) {
+  _LIBCPP_CONSTEXPR T zero = T(0);
+  assert(eps >= zero);
 
-    // Handle the zero cases
-    if (eps      == zero) return val == expected;
-    if (val      == zero) return std::abs(expected) <= eps;
-    if (expected == zero) return std::abs(val)      <= eps;
+  // Handle the zero cases
+  if (eps == zero)
+    return val == expected;
+  if (val == zero)
+    return std::abs(expected) <= eps;
+  if (expected == zero)
+    return std::abs(val) <= eps;
 
-    return std::abs(val - expected) < eps
-        && std::abs(val - expected)/std::abs(val) < eps;
+  return std::abs(val - expected) < eps && std::abs(val - expected) / std::abs(val) < eps;
 }
 
-template<typename T>
-bool fptest_close_pct(T val, T expected, T percent)
-{
-    constexpr T zero = T(0);
-    assert(percent >= zero);
-
-    // Handle the zero cases
-    if (percent == zero) return val == expected;
-    T eps = (percent / T(100)) * std::max(std::abs(val), std::abs(expected));
-
-    return fptest_close(val, expected, eps);
+template <typename T>
+bool fptest_close_pct(T val, T expected, T percent) {
+  assert(percent >= T(0));
+  T eps = (percent / T(100)) * std::max(std::abs(val), std::abs(expected));
+  return fptest_close(val, expected, eps);
 }
 
-
 #endif // SUPPORT_FP_COMPARE_H



More information about the llvm-branch-commits mailing list