[libc-commits] [libc] [libc] Add mpfr tests (PR #97376)

via libc-commits libc-commits at lists.llvm.org
Mon Jul 1 19:32:08 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-libc

Author: Job Henandez Lara (Jobhdez)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/97376.diff


5 Files Affected:

- (modified) libc/test/src/math/CMakeLists.txt (+12) 
- (added) libc/test/src/math/FMulTest.h (+121) 
- (added) libc/test/src/math/fmul_test.cpp (+13) 
- (modified) libc/utils/MPFRWrapper/MPFRUtils.cpp (+19-42) 
- (modified) libc/utils/MPFRWrapper/MPFRUtils.h (+28-12) 


``````````diff
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index c07c6d77fa233..9eda5db1ea2fc 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -1823,6 +1823,20 @@ add_fp_unittest(
     libc.src.__support.FPUtil.fp_bits
 )
 
+add_fp_unittest(
+  fmul_test
+  NEED_MPFR
+  SUITE
+    libc-math-unittests
+  SRCS
+    fmul_test.cpp
+  HDRS
+    FMulTest.h
+  DEPENDS
+    libc.src.math.fmul
+    libc.src.__support.FPUtil.fp_bits
+)
+
 add_fp_unittest(
   asinhf_test
   NEED_MPFR
diff --git a/libc/test/src/math/FMulTest.h b/libc/test/src/math/FMulTest.h
new file mode 100644
index 0000000000000..864910c29d83f
--- /dev/null
+++ b/libc/test/src/math/FMulTest.h
@@ -0,0 +1,128 @@
+//===-- Utility class to test fmul ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
+#define LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
+
+#include "src/__support/FPUtil/FPBits.h"
+#include "test/UnitTest/FEnvSafeTest.h"
+#include "test/UnitTest/FPMatcher.h"
+#include "test/UnitTest/Test.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
+
+// Tests a multiplication function `OutType func(InType, InType)` against MPFR.
+// Every check requires a correctly rounded result (at most 0.5 ULP of error) in
+// each rounding mode that can be set. OutType may be narrower than InType; for
+// example, fmul multiplies two doubles and returns a float.
+template <typename OutType, typename InType>
+class FmulMPFRTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
+
+  DECLARE_SPECIAL_CONSTANTS(InType)
+
+public:
+  typedef OutType (*FMulFunc)(InType, InType);
+
+  // Checks ordinary products, including some that land in the subnormal range
+  // of a narrower OutType (e.g. 2 * 2^-131 = 2^-130 is subnormal as a float).
+  void testFMulMPFR(FMulFunc func) {
+    const mpfr::BinaryInput<InType> INPUTS[] = {
+        {3.0, 5.0},
+        {0x1.0p1, 0x1.0p-131},
+        {0x1.0p2, 0x1.0p-129},
+        {1.0, 1.0},
+        {-0.0, -0.0},
+        {-0.0, 0.0},
+        {0.0, -0.0},
+        {0x1.0p100, 0x1.0p100},
+        {1.0, 1.0 + 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150},
+        {1.0, 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150}};
+
+    // Range-for over the array itself, so no separate element count has to be
+    // kept in sync with the initializer list.
+    for (const auto &input : INPUTS) {
+      ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, input,
+                                     func(input.x, input.y), 0.5);
+    }
+  }
+
+  // Checks products involving zeros, infinities, and quiet and signaling NaNs.
+  void testSpecialInputsMPFR(FMulFunc func) {
+    const mpfr::BinaryInput<InType> INPUTS[] = {
+        {inf, 0x1.0p-129},
+        {0x1.0p-129, inf},
+        {inf, 2.0},
+        {3.0, inf},
+        {0.0, 0.0},
+        {neg_inf, aNaN},
+        {aNaN, neg_inf},
+        {neg_inf, neg_inf},
+        {0.0, neg_inf},
+        {neg_inf, 0.0},
+        {neg_inf, 1.0},
+        {1.0, neg_inf},
+        {neg_inf, 0x1.0p-129},
+        {0x1.0p-129, neg_inf},
+        {0.0, 0x1.0p-129},
+        {inf, 0.0},
+        {0.0, inf},
+        {0.0, aNaN},
+        {2.0, aNaN},
+        {0x1.0p-129, aNaN},
+        {inf, aNaN},
+        {aNaN, aNaN},
+        {0.0, sNaN},
+        {2.0, sNaN},
+        {0x1.0p-129, sNaN},
+        {inf, sNaN},
+        {sNaN, sNaN}};
+
+    for (const auto &input : INPUTS) {
+      ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, input,
+                                     func(input.x, input.y), 0.5);
+    }
+  }
+
+  // Sweeps the normal range of InType with x rising and y falling by the same
+  // bit-pattern step, so their exponents roughly cancel and the products stay
+  // in range. All four sign combinations of (x, y) are covered.
+  void testNormalRange(FMulFunc func) {
+    using FPBits = LIBC_NAMESPACE::fputil::FPBits<InType>;
+    using StorageType = typename FPBits::StorageType;
+    static constexpr StorageType MAX_NORMAL = FPBits::max_normal().uintval();
+    static constexpr StorageType MIN_NORMAL = FPBits::min_normal().uintval();
+
+    constexpr StorageType COUNT = 10'001;
+    constexpr StorageType STEP = (MAX_NORMAL - MIN_NORMAL) / COUNT;
+    for (int signs = 0; signs < 4; ++signs) {
+      for (StorageType v = MIN_NORMAL, w = MAX_NORMAL;
+           v <= MAX_NORMAL && w >= MIN_NORMAL; v += STEP, w -= STEP) {
+        InType x = FPBits(v).get_val(), y = FPBits(w).get_val();
+        if (signs % 2 == 1) {
+          x = -x;
+        }
+        if (signs >= 2) {
+          y = -y;
+        }
+
+        mpfr::BinaryInput<InType> input{x, y};
+        ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, input, func(x, y),
+                                       0.5);
+      }
+    }
+  }
+};
+
+#define LIST_FMUL_MPFR_TESTS(OutType, InType, func)                            \
+  using LlvmLibcFmulTest = FmulMPFRTest<OutType, InType>;                      \
+  TEST_F(LlvmLibcFmulTest, MulMpfr) { testFMulMPFR(&func); }                   \
+  TEST_F(LlvmLibcFmulTest, NanInfMpfr) { testSpecialInputsMPFR(&func); }       \
+  TEST_F(LlvmLibcFmulTest, NormalRange) { testNormalRange(&func); }
+
+#endif // LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
diff --git a/libc/test/src/math/fmul_test.cpp b/libc/test/src/math/fmul_test.cpp
new file mode 100644
index 0000000000000..16eaa1a818daf
--- /dev/null
+++ b/libc/test/src/math/fmul_test.cpp
@@ -0,0 +1,13 @@
+//===-- Unittests for fmul ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "FMulTest.h"
+
+#include "src/math/fmul.h"
+
+LIST_FMUL_MPFR_TESTS(float, double, LIBC_NAMESPACE::fmul)
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 379a631a356a3..cf4ded5fb66b4 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -487,6 +487,13 @@ class MPFRNumber {
     return result;
   }
 
+  // Returns *this * b via mpfr_mul into a copy of *this, using mpfr_rounding.
+  MPFRNumber fmul(const MPFRNumber &b) {
+    MPFRNumber result(*this);
+    mpfr_mul(result.value, value, b.value, mpfr_rounding);
+    return result;
+  }
+
   cpp::string str() const {
     // 200 bytes should be more than sufficient to hold a 100-digit number
     // plus additional bytes for the decimal point, '-' sign etc.
@@ -738,6 +745,8 @@ binary_operation_one_output(Operation op, InputType x, InputType y,
     return inputX.hypot(inputY);
   case Operation::Pow:
     return inputX.pow(inputY);
+  case Operation::Fmul:
+    return inputX.fmul(inputY);
   default:
     __builtin_unreachable();
   }
@@ -947,6 +956,8 @@ explain_binary_operation_one_output_error(Operation, const BinaryInput<float> &,
                                           float, double, RoundingMode);
 template void explain_binary_operation_one_output_error(
     Operation, const BinaryInput<double> &, double, double, RoundingMode);
+template void explain_binary_operation_one_output_error(
+    Operation, const BinaryInput<double> &, float, double, RoundingMode);
 template void
 explain_binary_operation_one_output_error(Operation,
                                           const BinaryInput<long double> &,
@@ -1123,6 +1134,9 @@ template bool compare_binary_operation_one_output(Operation,
 template bool compare_binary_operation_one_output(Operation,
                                                   const BinaryInput<double> &,
                                                   double, double, RoundingMode);
+template bool compare_binary_operation_one_output(Operation,
+                                                  const BinaryInput<double> &,
+                                                  float, double, RoundingMode);
 template bool
 compare_binary_operation_one_output(Operation, const BinaryInput<long double> &,
                                     long double, double, RoundingMode);
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index 11e323bf6881d..7621866e6d730 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -41,6 +41,7 @@ enum class Operation : int {
   Exp10,
   Expm1,
   Floor,
+  Fmul, // NOTE(review): two-input op listed among one-input ops; confirm.
   Log,
   Log2,
   Log10,
@@ -147,6 +148,14 @@ template <typename T> struct IsTernaryInput<TernaryInput<T>> {
   static constexpr bool VALUE = true;
 };
 
+template <typename T> struct IsBinaryInput {
+  static constexpr bool VALUE = false; // Specialized below for BinaryInput.
+};
+
+template <typename T> struct IsBinaryInput<BinaryInput<T>> {
+  static constexpr bool VALUE = true;
+};
+
 template <typename T> struct MakeScalarInput : cpp::type_identity<T> {};
 
 template <typename T>
@@ -237,12 +246,14 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
   bool is_silent() const override { return silent; }
 
 private:
-  template <typename T, typename U> bool match(T in, U out) {
+  template <typename InType, typename OutType>
+  bool match(InType in, OutType out) {
     return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
                                                  rounding);
   }
 
-  template <typename T> bool match(T in, const BinaryOutput<T> &out) {
+  template <typename InType>
+  bool match(InType in, const BinaryOutput<InType> &out) {
     return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
                                                rounding);
   }
@@ -253,30 +264,33 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                                rounding);
   }
 
-  template <typename T>
-  bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
+  template <typename InType>
+  bool match(BinaryInput<InType> in, const BinaryOutput<InType> &out) {
     return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
                                                 rounding);
   }
 
-  template <typename T, typename U>
-  bool match(const TernaryInput<T> &in, U out) {
+  template <typename InType, typename OutType>
+  bool match(const TernaryInput<InType> &in, OutType out) {
     return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
                                                 rounding);
   }
 
-  template <typename T, typename U> void explain_error(T in, U out) {
+  template <typename InType, typename OutType>
+  void explain_error(InType in, OutType out) {
     explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
                                                 rounding);
   }
 
-  template <typename T> void explain_error(T in, const BinaryOutput<T> &out) {
+  template <typename InType>
+  void explain_error(InType in, const BinaryOutput<InType> &out) {
     explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
                                               rounding);
   }
 
-  template <typename T>
-  void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out) {
+  template <typename InType>
+  void explain_error(const BinaryInput<InType> &in,
+                     const BinaryOutput<InType> &out) {
     explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
                                                rounding);
   }
@@ -287,8 +301,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                               rounding);
   }
 
-  template <typename T, typename U>
-  void explain_error(const TernaryInput<T> &in, U out) {
+  template <typename InType, typename OutType>
+  void explain_error(const TernaryInput<InType> &in, OutType out) {
     explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
                                                rounding);
   }
@@ -311,6 +325,8 @@ constexpr bool is_valid_operation() {
       (op == Operation::Fma && internal::IsTernaryInput<InputType>::VALUE &&
        cpp::is_floating_point_v<
            typename internal::MakeScalarInput<InputType>::type> &&
+       cpp::is_floating_point_v<OutputType>) ||
+      (op == Operation::Fmul && internal::IsBinaryInput<InputType>::VALUE &&
        cpp::is_floating_point_v<OutputType>);
   if (IS_NARROWING_OP)
     return true;

``````````

</details>


https://github.com/llvm/llvm-project/pull/97376


More information about the libc-commits mailing list