[libc-commits] [libc] [libc][math] Update getpayload and fmul with NaN inputs. (PR #99812)

via libc-commits libc-commits at lists.llvm.org
Sun Jul 21 14:43:25 PDT 2024


https://github.com/lntue updated https://github.com/llvm/llvm-project/pull/99812

>From 2c7d580f7d84ae54d971499b681e9fe05fc06928 Mon Sep 17 00:00:00 2001
From: Tue Ly <lntue.h at gmail.com>
Date: Sun, 21 Jul 2024 15:55:26 +0000
Subject: [PATCH 1/2] [libc][math] Update getpayload and fmul with NaN inputs.

---
 libc/src/__support/FPUtil/BasicOperations.h |  7 ++-
 libc/src/__support/FPUtil/CMakeLists.txt    | 57 +++++++++++----------
 libc/src/__support/FPUtil/generic/mul.h     | 20 ++++----
 libc/test/src/math/smoke/MulTest.h          | 19 +------
 4 files changed, 46 insertions(+), 57 deletions(-)

diff --git a/libc/src/__support/FPUtil/BasicOperations.h b/libc/src/__support/FPUtil/BasicOperations.h
index a963a92bfb074..40f37b281b02c 100644
--- a/libc/src/__support/FPUtil/BasicOperations.h
+++ b/libc/src/__support/FPUtil/BasicOperations.h
@@ -11,8 +11,8 @@
 
 #include "FEnvImpl.h"
 #include "FPBits.h"
+#include "dyadic_float.h"
 
-#include "FEnvImpl.h"
 #include "src/__support/CPP/type_traits.h"
 #include "src/__support/common.h"
 #include "src/__support/macros/config.h"
@@ -274,7 +274,10 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> getpayload(T x) {
   if (!x_bits.is_nan())
     return T(-1.0);
 
-  return T(x_bits.uintval() & (FPBits::FRACTION_MASK >> 1));
+  DyadicFloat<FPBits::STORAGE_LEN> payload(
+      Sign::POS, 0, x_bits.uintval() & (FPBits::FRACTION_MASK >> 1));
+
+  return static_cast<T>(payload);
 }
 
 template <bool IsSignaling, typename T>
diff --git a/libc/src/__support/FPUtil/CMakeLists.txt b/libc/src/__support/FPUtil/CMakeLists.txt
index 793d3a121c742..8804f3a4d5e23 100644
--- a/libc/src/__support/FPUtil/CMakeLists.txt
+++ b/libc/src/__support/FPUtil/CMakeLists.txt
@@ -75,19 +75,6 @@ add_header_library(
     libc.src.__support.common
 )
 
-add_header_library(
-  basic_operations
-  HDRS
-    BasicOperations.h
-  DEPENDS
-    .fp_bits
-    .fenv_impl
-    libc.src.__support.CPP.type_traits
-    libc.src.__support.uint128
-    libc.src.__support.common
-    libc.src.__support.macros.optimization
-)
-
 add_header_library(
   division_and_remainder_operations
   HDRS
@@ -113,21 +100,6 @@ add_header_library(
 )
 
 
-add_header_library(
-  hypot
-  HDRS
-    Hypot.h
-  DEPENDS
-    .basic_operations
-    .fenv_impl
-    .fp_bits
-    .rounding_mode
-    libc.src.__support.common
-    libc.src.__support.CPP.bit
-    libc.src.__support.CPP.type_traits
-    libc.src.__support.uint128
-)
-
 add_header_library(
   sqrt
   HDRS
@@ -208,6 +180,35 @@ add_header_library(
     libc.src.__support.macros.optimization
 )
 
+add_header_library(
+  basic_operations
+  HDRS
+    BasicOperations.h
+  DEPENDS
+    .dyadic_float
+    .fp_bits
+    .fenv_impl
+    libc.src.__support.CPP.type_traits
+    libc.src.__support.uint128
+    libc.src.__support.common
+    libc.src.__support.macros.optimization
+)
+
+add_header_library(
+  hypot
+  HDRS
+    Hypot.h
+  DEPENDS
+    .basic_operations
+    .fenv_impl
+    .fp_bits
+    .rounding_mode
+    libc.src.__support.common
+    libc.src.__support.CPP.bit
+    libc.src.__support.CPP.type_traits
+    libc.src.__support.uint128
+)
+
 add_header_library(
   manipulation_functions
   HDRS
diff --git a/libc/src/__support/FPUtil/generic/mul.h b/libc/src/__support/FPUtil/generic/mul.h
index 02fc69c6cb1ba..61be3719e123a 100644
--- a/libc/src/__support/FPUtil/generic/mul.h
+++ b/libc/src/__support/FPUtil/generic/mul.h
@@ -50,19 +50,19 @@ mul(InType x, InType y) {
         raise_except_if_required(FE_INVALID);
 
       if (x_bits.is_quiet_nan()) {
-        InStorageType x_payload = static_cast<InStorageType>(getpayload(x));
-        if ((x_payload & ~(OutFPBits::FRACTION_MASK >> 1)) == 0)
-          return OutFPBits::quiet_nan(x_bits.sign(),
-                                      static_cast<OutStorageType>(x_payload))
-              .get_val();
+        InStorageType x_payload = x_bits.get_mantissa();
+        x_payload >>= (InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN);
+        return OutFPBits::quiet_nan(x_bits.sign(),
+                                    static_cast<OutStorageType>(x_payload))
+            .get_val();
       }
 
       if (y_bits.is_quiet_nan()) {
-        InStorageType y_payload = static_cast<InStorageType>(getpayload(y));
-        if ((y_payload & ~(OutFPBits::FRACTION_MASK >> 1)) == 0)
-          return OutFPBits::quiet_nan(y_bits.sign(),
-                                      static_cast<OutStorageType>(y_payload))
-              .get_val();
+        InStorageType y_payload = y_bits.get_mantissa();
+        y_payload >>= (InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN);
+        return OutFPBits::quiet_nan(y_bits.sign(),
+                                    static_cast<OutStorageType>(y_payload))
+            .get_val();
       }
 
       return OutFPBits::quiet_nan().get_val();
diff --git a/libc/test/src/math/smoke/MulTest.h b/libc/test/src/math/smoke/MulTest.h
index e2298eaeeb216..0c847e39687b7 100644
--- a/libc/test/src/math/smoke/MulTest.h
+++ b/libc/test/src/math/smoke/MulTest.h
@@ -38,23 +38,8 @@ class MulTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
     EXPECT_FP_IS_NAN_WITH_EXCEPTION(func(sNaN, sNaN), FE_INVALID);
 
     InType qnan_42 = InFPBits::quiet_nan(Sign::POS, 0x42).get_val();
-    EXPECT_FP_EQ(InType(0x42.0p+0),
-                 LIBC_NAMESPACE::fputil::getpayload(func(qnan_42, zero)));
-    EXPECT_FP_EQ(InType(0x42.0p+0),
-                 LIBC_NAMESPACE::fputil::getpayload(func(zero, qnan_42)));
-
-    if constexpr (sizeof(OutType) < sizeof(InType)) {
-      InStorageType max_payload = InFPBits::FRACTION_MASK >> 1;
-      InType qnan_max = InFPBits::quiet_nan(Sign::POS, max_payload).get_val();
-      EXPECT_FP_EQ(zero,
-                   LIBC_NAMESPACE::fputil::getpayload(func(qnan_max, zero)));
-      EXPECT_FP_EQ(zero,
-                   LIBC_NAMESPACE::fputil::getpayload(func(zero, qnan_max)));
-      EXPECT_FP_EQ(InType(0x42.0p+0),
-                   LIBC_NAMESPACE::fputil::getpayload(func(qnan_max, qnan_42)));
-      EXPECT_FP_EQ(InType(0x42.0p+0),
-                   LIBC_NAMESPACE::fputil::getpayload(func(qnan_42, qnan_max)));
-    }
+    EXPECT_FP_IS_NAN(func(qnan_42, zero));
+    EXPECT_FP_IS_NAN(func(zero, qnan_42));
 
     EXPECT_FP_EQ(inf, func(inf, InType(1.0)));
     EXPECT_FP_EQ(neg_inf, func(neg_inf, InType(1.0)));

>From 36e7ca40049922fe7e77275b16f75b423d07332a Mon Sep 17 00:00:00 2001
From: Tue Ly <lntue.h at gmail.com>
Date: Sun, 21 Jul 2024 21:42:38 +0000
Subject: [PATCH 2/2] Address comments and fix the same issue for add_sub, div,
 and FMA.

---
 libc/src/__support/FPUtil/BasicOperations.h | 12 ++++--
 libc/src/__support/FPUtil/generic/FMA.h     | 43 ++++++++++++---------
 libc/src/__support/FPUtil/generic/add_sub.h | 26 +++++++------
 libc/src/__support/FPUtil/generic/div.h     | 20 +++++-----
 libc/src/__support/FPUtil/generic/mul.h     |  4 +-
 5 files changed, 60 insertions(+), 45 deletions(-)

diff --git a/libc/src/__support/FPUtil/BasicOperations.h b/libc/src/__support/FPUtil/BasicOperations.h
index 40f37b281b02c..3b7d7a5c249ae 100644
--- a/libc/src/__support/FPUtil/BasicOperations.h
+++ b/libc/src/__support/FPUtil/BasicOperations.h
@@ -269,15 +269,21 @@ totalordermag(T x, T y) {
 template <typename T>
 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> getpayload(T x) {
   using FPBits = FPBits<T>;
+  using StorageType = typename FPBits::StorageType;
   FPBits x_bits(x);
 
   if (!x_bits.is_nan())
     return T(-1.0);
 
-  DyadicFloat<FPBits::STORAGE_LEN> payload(
-      Sign::POS, 0, x_bits.uintval() & (FPBits::FRACTION_MASK >> 1));
+  StorageType payload = x_bits.uintval() & (FPBits::FRACTION_MASK >> 1);
+
+  if constexpr (is_big_int_v<StorageType>) {
+    DyadicFloat<FPBits::STORAGE_LEN> payload_dfloat(Sign::POS, 0, payload);
 
-  return static_cast<T>(payload);
+    return static_cast<T>(payload_dfloat);
+  } else {
+    return static_cast<T>(payload);
+  }
 }
 
 template <bool IsSignaling, typename T>
diff --git a/libc/src/__support/FPUtil/generic/FMA.h b/libc/src/__support/FPUtil/generic/FMA.h
index 337301d86d919..87806b20f8cfa 100644
--- a/libc/src/__support/FPUtil/generic/FMA.h
+++ b/libc/src/__support/FPUtil/generic/FMA.h
@@ -129,27 +129,27 @@ fma(InType x, InType y, InType z) {
         raise_except_if_required(FE_INVALID);
 
       if (x_bits.is_quiet_nan()) {
-        InStorageType x_payload = static_cast<InStorageType>(getpayload(x));
-        if ((x_payload & ~(OutFPBits::FRACTION_MASK >> 1)) == 0)
-          return OutFPBits::quiet_nan(x_bits.sign(),
-                                      static_cast<OutStorageType>(x_payload))
-              .get_val();
+        InStorageType x_payload = x_bits.get_mantissa();
+        x_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
+        return OutFPBits::quiet_nan(x_bits.sign(),
+                                    static_cast<OutStorageType>(x_payload))
+            .get_val();
       }
 
       if (y_bits.is_quiet_nan()) {
-        InStorageType y_payload = static_cast<InStorageType>(getpayload(y));
-        if ((y_payload & ~(OutFPBits::FRACTION_MASK >> 1)) == 0)
-          return OutFPBits::quiet_nan(y_bits.sign(),
-                                      static_cast<OutStorageType>(y_payload))
-              .get_val();
+        InStorageType y_payload = y_bits.get_mantissa();
+        y_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
+        return OutFPBits::quiet_nan(y_bits.sign(),
+                                    static_cast<OutStorageType>(y_payload))
+            .get_val();
       }
 
       if (z_bits.is_quiet_nan()) {
-        InStorageType z_payload = static_cast<InStorageType>(getpayload(z));
-        if ((z_payload & ~(OutFPBits::FRACTION_MASK >> 1)) == 0)
-          return OutFPBits::quiet_nan(z_bits.sign(),
-                                      static_cast<OutStorageType>(z_payload))
-              .get_val();
+        InStorageType z_payload = z_bits.get_mantissa();
+        z_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
+        return OutFPBits::quiet_nan(z_bits.sign(),
+                                    static_cast<OutStorageType>(z_payload))
+            .get_val();
       }
 
       return OutFPBits::quiet_nan().get_val();
@@ -163,18 +163,25 @@ fma(InType x, InType y, InType z) {
   int y_exp = 0;
   int z_exp = 0;
 
+  // Denormal scaling = 2^(fraction length).
+  constexpr InType DENORMAL_SCALING =
+      InFPBits(static_cast<InStorageType>(InFPBits::FRACTION_LEN +
+                                          InFPBits::EXP_BIAS)
+               << InFPBits::SIG_LEN)
+          .get_val();
+
   // Normalize denormal inputs.
   if (LIBC_UNLIKELY(InFPBits(x).is_subnormal())) {
     x_exp -= InFPBits::FRACTION_LEN;
-    x *= InType(InStorageType(1) << InFPBits::FRACTION_LEN);
+    x *= DENORMAL_SCALING;
   }
   if (LIBC_UNLIKELY(InFPBits(y).is_subnormal())) {
     y_exp -= InFPBits::FRACTION_LEN;
-    y *= InType(InStorageType(1) << InFPBits::FRACTION_LEN);
+    y *= DENORMAL_SCALING;
   }
   if (LIBC_UNLIKELY(InFPBits(z).is_subnormal())) {
     z_exp -= InFPBits::FRACTION_LEN;
-    z *= InType(InStorageType(1) << InFPBits::FRACTION_LEN);
+    z *= DENORMAL_SCALING;
   }
 
   x_bits = InFPBits(x);
diff --git a/libc/src/__support/FPUtil/generic/add_sub.h b/libc/src/__support/FPUtil/generic/add_sub.h
index ec20a8723b704..850db3f83209e 100644
--- a/libc/src/__support/FPUtil/generic/add_sub.h
+++ b/libc/src/__support/FPUtil/generic/add_sub.h
@@ -56,19 +56,19 @@ add_or_sub(InType x, InType y) {
         raise_except_if_required(FE_INVALID);
 
       if (x_bits.is_quiet_nan()) {
-        InStorageType x_payload = static_cast<InStorageType>(getpayload(x));
-        if ((x_payload & ~(OutFPBits::FRACTION_MASK >> 1)) == 0)
-          return OutFPBits::quiet_nan(x_bits.sign(),
-                                      static_cast<OutStorageType>(x_payload))
-              .get_val();
+        InStorageType x_payload = x_bits.get_mantissa();
+        x_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
+        return OutFPBits::quiet_nan(x_bits.sign(),
+                                    static_cast<OutStorageType>(x_payload))
+            .get_val();
       }
 
       if (y_bits.is_quiet_nan()) {
-        InStorageType y_payload = static_cast<InStorageType>(getpayload(y));
-        if ((y_payload & ~(OutFPBits::FRACTION_MASK >> 1)) == 0)
-          return OutFPBits::quiet_nan(y_bits.sign(),
-                                      static_cast<OutStorageType>(y_payload))
-              .get_val();
+        InStorageType y_payload = y_bits.get_mantissa();
+        y_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
+        return OutFPBits::quiet_nan(y_bits.sign(),
+                                    static_cast<OutStorageType>(y_payload))
+            .get_val();
       }
 
       return OutFPBits::quiet_nan().get_val();
@@ -174,10 +174,12 @@ add_or_sub(InType x, InType y) {
     else
       aligned_min_mant_sticky = true;
 
+    InStorageType min_mant_sticky(static_cast<int>(aligned_min_mant_sticky));
+
     if (is_effectively_add)
-      result_mant = max_mant + (aligned_min_mant | aligned_min_mant_sticky);
+      result_mant = max_mant + (aligned_min_mant | min_mant_sticky);
     else
-      result_mant = max_mant - (aligned_min_mant | aligned_min_mant_sticky);
+      result_mant = max_mant - (aligned_min_mant | min_mant_sticky);
   }
 
   int result_exp = max_bits.get_exponent() - RESULT_FRACTION_LEN;
diff --git a/libc/src/__support/FPUtil/generic/div.h b/libc/src/__support/FPUtil/generic/div.h
index 27545786aea17..dad1772fce750 100644
--- a/libc/src/__support/FPUtil/generic/div.h
+++ b/libc/src/__support/FPUtil/generic/div.h
@@ -49,19 +49,19 @@ div(InType x, InType y) {
         raise_except_if_required(FE_INVALID);
 
       if (x_bits.is_quiet_nan()) {
-        InStorageType x_payload = static_cast<InStorageType>(getpayload(x));
-        if ((x_payload & ~(OutFPBits::FRACTION_MASK >> 1)) == 0)
-          return OutFPBits::quiet_nan(x_bits.sign(),
-                                      static_cast<OutStorageType>(x_payload))
-              .get_val();
+        InStorageType x_payload = x_bits.get_mantissa();
+        x_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
+        return OutFPBits::quiet_nan(x_bits.sign(),
+                                    static_cast<OutStorageType>(x_payload))
+            .get_val();
       }
 
       if (y_bits.is_quiet_nan()) {
-        InStorageType y_payload = static_cast<InStorageType>(getpayload(y));
-        if ((y_payload & ~(OutFPBits::FRACTION_MASK >> 1)) == 0)
-          return OutFPBits::quiet_nan(y_bits.sign(),
-                                      static_cast<OutStorageType>(y_payload))
-              .get_val();
+        InStorageType y_payload = y_bits.get_mantissa();
+        y_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
+        return OutFPBits::quiet_nan(y_bits.sign(),
+                                    static_cast<OutStorageType>(y_payload))
+            .get_val();
       }
 
       return OutFPBits::quiet_nan().get_val();
diff --git a/libc/src/__support/FPUtil/generic/mul.h b/libc/src/__support/FPUtil/generic/mul.h
index 61be3719e123a..20d9a77792762 100644
--- a/libc/src/__support/FPUtil/generic/mul.h
+++ b/libc/src/__support/FPUtil/generic/mul.h
@@ -51,7 +51,7 @@ mul(InType x, InType y) {
 
       if (x_bits.is_quiet_nan()) {
         InStorageType x_payload = x_bits.get_mantissa();
-        x_payload >>= (InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN);
+        x_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
         return OutFPBits::quiet_nan(x_bits.sign(),
                                     static_cast<OutStorageType>(x_payload))
             .get_val();
@@ -59,7 +59,7 @@ mul(InType x, InType y) {
 
       if (y_bits.is_quiet_nan()) {
         InStorageType y_payload = y_bits.get_mantissa();
-        y_payload >>= (InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN);
+        y_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
         return OutFPBits::quiet_nan(y_bits.sign(),
                                     static_cast<OutStorageType>(y_payload))
             .get_val();



More information about the libc-commits mailing list