[libc-commits] [libc] [llvm] [libc][math] Fix signaling nan handling of hypot(f) and improve hypotf performance. (PR #99432)

via libc-commits libc-commits at lists.llvm.org
Thu Jul 18 14:37:03 PDT 2024


================
@@ -6,66 +6,88 @@
 //
 //===----------------------------------------------------------------------===//
 #include "src/math/hypotf.h"
-#include "src/__support/FPUtil/BasicOperations.h"
+#include "src/__support/FPUtil/FEnvImpl.h"
 #include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/FPUtil/double_double.h"
+#include "src/__support/FPUtil/multiply_add.h"
 #include "src/__support/FPUtil/sqrt.h"
 #include "src/__support/common.h"
 #include "src/__support/macros/config.h"
+#include "src/__support/macros/optimization.h"
 
 namespace LIBC_NAMESPACE_DECL {
 
 LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
   using DoubleBits = fputil::FPBits<double>;
   using FPBits = fputil::FPBits<float>;
 
-  FPBits x_bits(x), y_bits(y);
+  FPBits x_abs = FPBits(x).abs();
+  FPBits y_abs = FPBits(y).abs();
 
-  uint16_t x_exp = x_bits.get_biased_exponent();
-  uint16_t y_exp = y_bits.get_biased_exponent();
-  uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp);
+  bool x_abs_larger = x_abs.uintval() >= y_abs.uintval();
 
-  if (exp_diff >= FPBits::FRACTION_LEN + 2) {
-    return fputil::abs(x) + fputil::abs(y);
-  }
+  FPBits a_bits = x_abs_larger ? x_abs : y_abs;
+  FPBits b_bits = x_abs_larger ? y_abs : x_abs;
 
-  double xd = static_cast<double>(x);
-  double yd = static_cast<double>(y);
+  uint32_t a_u = a_bits.uintval();
+  uint32_t b_u = b_bits.uintval();
 
-  // These squares are exact.
-  double x_sq = xd * xd;
-  double y_sq = yd * yd;
+  if (LIBC_UNLIKELY(a_u >= FPBits::EXP_MASK)) {
+    // x or y is inf or nan
+    if (a_bits.is_signaling_nan() || b_bits.is_signaling_nan()) {
+      fputil::raise_except_if_required(FE_INVALID);
+      return FPBits::quiet_nan().get_val();
+    }
+    if (a_bits.is_inf() || b_bits.is_inf())
+      return FPBits::inf().get_val();
+    return a_bits.get_val();
+  }
 
-  // Compute the sum of squares.
-  double sum_sq = x_sq + y_sq;
+  if (LIBC_UNLIKELY(a_u - b_u >=
+                    static_cast<uint32_t>((FPBits::FRACTION_LEN + 2)
+                                          << FPBits::FRACTION_LEN)))
+    return x_abs.get_val() + y_abs.get_val();
 
-  // Compute the rounding error with Fast2Sum algorithm:
-  // x_sq + y_sq = sum_sq - err
-  double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq;
+  double ad = static_cast<double>(a_bits.get_val());
+  double bd = static_cast<double>(b_bits.get_val());
+
+  // These squares are exact.
+  double a_sq = ad * ad;
+#ifdef LIBC_TARGET_CPU_HAS_FMA
+  double sum_sq = fputil::multiply_add(bd, bd, a_sq);
+#else
+  double b_sq = bd * bd;
+  double sum_sq = a_sq + b_sq;
+#endif
 
   // Take sqrt in double precision.
   DoubleBits result(fputil::sqrt<double>(sum_sq));
+  uint64_t r_u = result.uintval();
 
-  if (!DoubleBits(sum_sq).is_inf_or_nan()) {
-    // Correct rounding.
-    double r_sq = result.get_val() * result.get_val();
-    double diff = sum_sq - r_sq;
-    constexpr uint64_t MASK = 0x0000'0000'3FFF'FFFFULL;
-    uint64_t lrs = result.uintval() & MASK;
-
-    if (lrs == 0x0000'0000'1000'0000ULL && err < diff) {
-      result.set_uintval(result.uintval() | 1ULL);
-    } else if (lrs == 0x0000'0000'3000'0000ULL && err > diff) {
-      result.set_uintval(result.uintval() - 1ULL);
-    }
-  } else {
-    FPBits bits_x(x), bits_y(y);
-    if (bits_x.is_inf_or_nan() || bits_y.is_inf_or_nan()) {
-      if (bits_x.is_inf() || bits_y.is_inf())
-        return FPBits::inf().get_val();
-      if (bits_x.is_nan())
-        return x;
-      return y;
+  // If any of the sticky bits of the result are non-zero, except the LSB, then
+  // the rounded result is correct.
+  if (LIBC_UNLIKELY(((r_u + 1) & 0x0000'0000'0FFF'FFFE) == 0)) {
+    double r_d = result.get_val();
+
+    // Perform rounding correction.
+#ifdef LIBC_TARGET_CPU_HAS_FMA
+    double sum_sq_lo = fputil::multiply_add(bd, bd, a_sq - sum_sq);
+    double err = sum_sq_lo - fputil::multiply_add(r_d, r_d, -sum_sq);
+#else
+    fputil::DoubleDouble r_sq = fputil::exact_mult(r_d, r_d);
+    double sum_sq_lo = b_sq - (sum_sq - a_sq);
+    double err = (sum_sq - r_sq.hi) + (sum_sq_lo - r_sq.lo);
+#endif
+
+    if (err > 0)
+      r_u |= 1;
+    else if ((err < 0) && (r_u & 1) == 0)
+      r_u -= 1;
+    else if ((r_u & 0x0000'0000'1FFF'FFFF) == 0) {
+      // The rounded result is exact.
+      fputil::clear_except_if_required(FE_INEXACT);
     }
----------------
overmighty wrote:

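Relevant example from the LLVM Coding Standards. Once one branch of an `if`/`else` chain needs braces, brace every branch so the chain stays uniform. Here the final `else if` branch holds a comment plus a statement, so it needs braces: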

```cpp
// Use braces for the `if` block to keep it uniform with the `else` block.
if (isa<FunctionDecl>(D)) {
  handleFunctionDecl(D);
} else {
  // In this `else` case, it is necessary that we explain the situation with
  // this surprisingly long comment, so it would be unclear without the braces
  // whether the following statement is in the scope of the `if`.
  handleOtherDecl(D);
}
```

https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements

```suggestion
    if (err > 0) {
      r_u |= 1;
    } else if ((err < 0) && (r_u & 1) == 0) {
      r_u -= 1;
    } else if ((r_u & 0x0000'0000'1FFF'FFFF) == 0) {
      // The rounded result is exact.
      fputil::clear_except_if_required(FE_INEXACT);
    }
```


https://github.com/llvm/llvm-project/pull/99432


More information about the libc-commits mailing list