[libc-commits] [libc] [libc][math][c23] Add MPFR exhaustive test for fmodf16 (PR #94656)

via libc-commits libc-commits at lists.llvm.org
Mon Jun 24 14:03:58 PDT 2024


https://github.com/overmighty updated https://github.com/llvm/llvm-project/pull/94656

>From b3f57e982821f13897259503d929ac358e262b30 Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Mon, 24 Jun 2024 21:25:43 +0200
Subject: [PATCH 1/3] [libc] Restore libc/utils/MPFRWrapper from branch
 overmighty:libc-math-f16divf

---
 libc/utils/MPFRWrapper/MPFRUtils.cpp | 84 ++++++++++++++++++----------
 libc/utils/MPFRWrapper/MPFRUtils.h   | 39 +++++++++----
 2 files changed, 81 insertions(+), 42 deletions(-)

diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 2eac4dd8e199d..521c2658b327a 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -296,6 +296,12 @@ class MPFRNumber {
     return result;
   }
 
+  MPFRNumber div(const MPFRNumber &b) const {
+    MPFRNumber result(*this);
+    mpfr_div(result.value, value, b.value, mpfr_rounding);
+    return result;
+  }
+
   MPFRNumber floor() const {
     MPFRNumber result(*this);
     mpfr_floor(result.value, value);
@@ -708,6 +714,8 @@ binary_operation_one_output(Operation op, InputType x, InputType y,
   switch (op) {
   case Operation::Atan2:
     return inputX.atan2(inputY);
+  case Operation::Div:
+    return inputX.div(inputY);
   case Operation::Fmod:
     return inputX.fmod(inputY);
   case Operation::Hypot:
@@ -885,42 +893,47 @@ template void explain_binary_operation_two_outputs_error<long double>(
     Operation, const BinaryInput<long double> &,
     const BinaryOutput<long double> &, double, RoundingMode);
 
-template <typename T>
-void explain_binary_operation_one_output_error(Operation op,
-                                               const BinaryInput<T> &input,
-                                               T libc_result,
-                                               double ulp_tolerance,
-                                               RoundingMode rounding) {
-  unsigned int precision = get_precision<T>(ulp_tolerance);
+template <typename InputType, typename OutputType>
+void explain_binary_operation_one_output_error(
+    Operation op, const BinaryInput<InputType> &input, OutputType libc_result,
+    double ulp_tolerance, RoundingMode rounding) {
+  unsigned int precision = get_precision<InputType>(ulp_tolerance);
   MPFRNumber mpfrX(input.x, precision);
   MPFRNumber mpfrY(input.y, precision);
-  FPBits<T> xbits(input.x);
-  FPBits<T> ybits(input.y);
+  FPBits<InputType> xbits(input.x);
+  FPBits<InputType> ybits(input.y);
   MPFRNumber mpfr_result =
       binary_operation_one_output(op, input.x, input.y, precision, rounding);
   MPFRNumber mpfrMatchValue(libc_result);
 
   tlog << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n';
-  tlog << "First input bits: " << str(FPBits<T>(input.x)) << '\n';
-  tlog << "Second input bits: " << str(FPBits<T>(input.y)) << '\n';
+  tlog << "First input bits: " << str(FPBits<InputType>(input.x)) << '\n';
+  tlog << "Second input bits: " << str(FPBits<InputType>(input.y)) << '\n';
 
   tlog << "Libc result: " << mpfrMatchValue.str() << '\n'
        << "MPFR result: " << mpfr_result.str() << '\n';
-  tlog << "Libc floating point result bits: " << str(FPBits<T>(libc_result))
-       << '\n';
+  tlog << "Libc floating point result bits: "
+       << str(FPBits<OutputType>(libc_result)) << '\n';
   tlog << "              MPFR rounded bits: "
-       << str(FPBits<T>(mpfr_result.as<T>())) << '\n';
+       << str(FPBits<OutputType>(mpfr_result.as<OutputType>())) << '\n';
   tlog << "ULP error: " << mpfr_result.ulp_as_mpfr_number(libc_result).str()
        << '\n';
 }
 
-template void explain_binary_operation_one_output_error<float>(
-    Operation, const BinaryInput<float> &, float, double, RoundingMode);
-template void explain_binary_operation_one_output_error<double>(
+template void
+explain_binary_operation_one_output_error(Operation, const BinaryInput<float> &,
+                                          float, double, RoundingMode);
+template void explain_binary_operation_one_output_error(
     Operation, const BinaryInput<double> &, double, double, RoundingMode);
-template void explain_binary_operation_one_output_error<long double>(
-    Operation, const BinaryInput<long double> &, long double, double,
-    RoundingMode);
+template void
+explain_binary_operation_one_output_error(Operation,
+                                          const BinaryInput<long double> &,
+                                          long double, double, RoundingMode);
+#ifdef LIBC_TYPES_HAS_FLOAT16
+template void
+explain_binary_operation_one_output_error(Operation, const BinaryInput<float> &,
+                                          float16, double, RoundingMode);
+#endif
 
 template <typename InputType, typename OutputType>
 void explain_ternary_operation_one_output_error(
@@ -1051,12 +1064,13 @@ template bool compare_binary_operation_two_outputs<long double>(
     Operation, const BinaryInput<long double> &,
     const BinaryOutput<long double> &, double, RoundingMode);
 
-template <typename T>
+template <typename InputType, typename OutputType>
 bool compare_binary_operation_one_output(Operation op,
-                                         const BinaryInput<T> &input,
-                                         T libc_result, double ulp_tolerance,
+                                         const BinaryInput<InputType> &input,
+                                         OutputType libc_result,
+                                         double ulp_tolerance,
                                          RoundingMode rounding) {
-  unsigned int precision = get_precision<T>(ulp_tolerance);
+  unsigned int precision = get_precision<InputType>(ulp_tolerance);
   MPFRNumber mpfr_result =
       binary_operation_one_output(op, input.x, input.y, precision, rounding);
   double ulp = mpfr_result.ulp(libc_result);
@@ -1064,13 +1078,21 @@ bool compare_binary_operation_one_output(Operation op,
   return (ulp <= ulp_tolerance);
 }
 
-template bool compare_binary_operation_one_output<float>(
-    Operation, const BinaryInput<float> &, float, double, RoundingMode);
-template bool compare_binary_operation_one_output<double>(
-    Operation, const BinaryInput<double> &, double, double, RoundingMode);
-template bool compare_binary_operation_one_output<long double>(
-    Operation, const BinaryInput<long double> &, long double, double,
-    RoundingMode);
+template bool compare_binary_operation_one_output(Operation,
+                                                  const BinaryInput<float> &,
+                                                  float, double, RoundingMode);
+template bool compare_binary_operation_one_output(Operation,
+                                                  const BinaryInput<double> &,
+                                                  double, double, RoundingMode);
+template bool
+compare_binary_operation_one_output(Operation, const BinaryInput<long double> &,
+                                    long double, double, RoundingMode);
+#ifdef LIBC_TYPES_HAS_FLOAT16
+template bool compare_binary_operation_one_output(Operation,
+                                                  const BinaryInput<float> &,
+                                                  float16, double,
+                                                  RoundingMode);
+#endif
 
 template <typename InputType, typename OutputType>
 bool compare_ternary_operation_one_output(Operation op,
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index 0b4f42a72ec81..46f3375fd4b7e 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -71,6 +71,7 @@ enum class Operation : int {
   // output.
   BeginBinaryOperationsSingleOutput,
   Atan2,
+  Div,
   Fmod,
   Hypot,
   Pow,
@@ -129,6 +130,14 @@ struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
   static constexpr bool VALUE = cpp::is_floating_point_v<T>;
 };
 
+template <typename T> struct IsBinaryInput {
+  static constexpr bool VALUE = false;
+};
+
+template <typename T> struct IsBinaryInput<BinaryInput<T>> {
+  static constexpr bool VALUE = true;
+};
+
 template <typename T> struct IsTernaryInput {
   static constexpr bool VALUE = false;
 };
@@ -139,6 +148,9 @@ template <typename T> struct IsTernaryInput<TernaryInput<T>> {
 
 template <typename T> struct MakeScalarInput : cpp::type_identity<T> {};
 
+template <typename T>
+struct MakeScalarInput<BinaryInput<T>> : cpp::type_identity<T> {};
+
 template <typename T>
 struct MakeScalarInput<TernaryInput<T>> : cpp::type_identity<T> {};
 
@@ -159,10 +171,11 @@ bool compare_binary_operation_two_outputs(Operation op,
                                           double ulp_tolerance,
                                           RoundingMode rounding);
 
-template <typename T>
+template <typename InputType, typename OutputType>
 bool compare_binary_operation_one_output(Operation op,
-                                         const BinaryInput<T> &input,
-                                         T libc_output, double ulp_tolerance,
+                                         const BinaryInput<InputType> &input,
+                                         OutputType libc_output,
+                                         double ulp_tolerance,
                                          RoundingMode rounding);
 
 template <typename InputType, typename OutputType>
@@ -187,12 +200,10 @@ void explain_binary_operation_two_outputs_error(
     const BinaryOutput<T> &match_value, double ulp_tolerance,
     RoundingMode rounding);
 
-template <typename T>
-void explain_binary_operation_one_output_error(Operation op,
-                                               const BinaryInput<T> &input,
-                                               T match_value,
-                                               double ulp_tolerance,
-                                               RoundingMode rounding);
+template <typename InputType, typename OutputType>
+void explain_binary_operation_one_output_error(
+    Operation op, const BinaryInput<InputType> &input, OutputType match_value,
+    double ulp_tolerance, RoundingMode rounding);
 
 template <typename InputType, typename OutputType>
 void explain_ternary_operation_one_output_error(
@@ -235,7 +246,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                                rounding);
   }
 
-  template <typename T> bool match(const BinaryInput<T> &in, T out) {
+  template <typename T, typename U>
+  bool match(const BinaryInput<T> &in, U out) {
     return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
                                                rounding);
   }
@@ -268,7 +280,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                                rounding);
   }
 
-  template <typename T> void explain_error(const BinaryInput<T> &in, T out) {
+  template <typename T, typename U>
+  void explain_error(const BinaryInput<T> &in, U out) {
     explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
                                               rounding);
   }
@@ -290,6 +303,10 @@ constexpr bool is_valid_operation() {
       (op == Operation::Sqrt && cpp::is_floating_point_v<InputType> &&
        cpp::is_floating_point_v<OutputType> &&
        sizeof(OutputType) <= sizeof(InputType)) ||
+      (op == Operation::Div && internal::IsBinaryInput<InputType>::VALUE &&
+       cpp::is_floating_point_v<
+           typename internal::MakeScalarInput<InputType>::type> &&
+       cpp::is_floating_point_v<OutputType>) ||
       (op == Operation::Fma && internal::IsTernaryInput<InputType>::VALUE &&
        cpp::is_floating_point_v<
            typename internal::MakeScalarInput<InputType>::type> &&

>From 5d66a52f315ac964f3cc5440a44b2a6f4f5ba908 Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Thu, 6 Jun 2024 20:29:46 +0200
Subject: [PATCH 2/3] [libc][math][c23] Add MPFR exhaustive test for fmodf16

---
 libc/test/src/math/exhaustive/CMakeLists.txt  |  15 ++
 .../src/math/exhaustive/exhaustive_test.h     | 150 ++++++++++++++++++
 .../test/src/math/exhaustive/fmodf16_test.cpp |  41 +++++
 libc/utils/MPFRWrapper/MPFRUtils.cpp          |   6 +
 4 files changed, 212 insertions(+)
 create mode 100644 libc/test/src/math/exhaustive/fmodf16_test.cpp

diff --git a/libc/test/src/math/exhaustive/CMakeLists.txt b/libc/test/src/math/exhaustive/CMakeLists.txt
index 34df8720ed4db..422934d2f51ff 100644
--- a/libc/test/src/math/exhaustive/CMakeLists.txt
+++ b/libc/test/src/math/exhaustive/CMakeLists.txt
@@ -277,6 +277,21 @@ add_fp_unittest(
     libc.src.__support.FPUtil.generic.fmod
 )
 
+add_fp_unittest(
+  fmodf16_test
+  NO_RUN_POSTBUILD
+  NEED_MPFR
+  SUITE
+    libc_math_exhaustive_tests
+  SRCS
+    fmodf16_test.cpp
+  DEPENDS
+    .exhaustive_test
+    libc.src.math.fmodf16
+  LINK_LIBRARIES
+    -lpthread
+)
+
 add_fp_unittest(
   coshf_test
   NO_RUN_POSTBUILD
diff --git a/libc/test/src/math/exhaustive/exhaustive_test.h b/libc/test/src/math/exhaustive/exhaustive_test.h
index 13e272783250b..9afbdabb7fd1f 100644
--- a/libc/test/src/math/exhaustive/exhaustive_test.h
+++ b/libc/test/src/math/exhaustive/exhaustive_test.h
@@ -8,6 +8,7 @@
 
 #include "src/__support/CPP/type_traits.h"
 #include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/macros/properties/types.h"
 #include "test/UnitTest/FPMatcher.h"
 #include "test/UnitTest/Test.h"
 #include "utils/MPFRWrapper/MPFRUtils.h"
@@ -68,6 +69,43 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
   }
 };
 
+template <typename T> using BinaryOp = T(T, T);
+
+template <typename T, mpfr::Operation Op, BinaryOp<T> Func>
+struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
+  using FloatType = T;
+  using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
+  using StorageType = typename FPBits::StorageType;
+
+  static constexpr BinaryOp<FloatType> *FUNC = Func;
+
+  // Check in a range, return the number of failures.
+  uint64_t check(StorageType x_start, StorageType x_stop, StorageType y_start,
+                 StorageType y_stop, mpfr::RoundingMode rounding) {
+    mpfr::ForceRoundingMode r(rounding);
+    if (!r.success)
+      return (x_stop > x_start || y_stop > y_start);
+    StorageType xbits = x_start;
+    uint64_t failed = 0;
+    do {
+      FloatType x = FPBits(xbits).get_val();
+      StorageType ybits = y_start;
+      do {
+        FloatType y = FPBits(ybits).get_val();
+        mpfr::BinaryInput<FloatType> input{x, y};
+        bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, FUNC(x, y),
+                                                         0.5, rounding);
+        failed += (!correct);
+        // Uncomment to print out failed values.
+        // if (!correct) {
+        //   TEST_MPFR_MATCH(Op::Operation, x, Op::func(x, y), 0.5, rounding);
+        // }
+      } while (ybits++ < y_stop);
+    } while (xbits++ < x_stop);
+    return failed;
+  }
+};
+
 // Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
 //   StorageType and check method.
 template <typename Checker>
@@ -167,6 +205,114 @@ struct LlvmLibcExhaustiveMathTest
   };
 };
 
+template <typename Checker>
+struct LlvmLibcBinaryInputExhaustiveMathTest
+    : public virtual LIBC_NAMESPACE::testing::Test,
+      public Checker {
+  using FloatType = typename Checker::FloatType;
+  using FPBits = typename Checker::FPBits;
+  using StorageType = typename Checker::StorageType;
+
+  static constexpr StorageType Increment = (1 << 2);
+
+  // Break [start, stop) into `nthreads` subintervals and apply *check to each
+  // subinterval in parallel.
+  void test_full_range(StorageType x_start, StorageType x_stop,
+                       StorageType y_start, StorageType y_stop,
+                       mpfr::RoundingMode rounding) {
+    int n_threads = std::thread::hardware_concurrency();
+    std::vector<std::thread> thread_list;
+    std::mutex mx_cur_val;
+    int current_percent = -1;
+    StorageType current_value = x_start;
+    std::atomic<uint64_t> failed(0);
+
+    for (int i = 0; i < n_threads; ++i) {
+      thread_list.emplace_back([&, this]() {
+        while (true) {
+          StorageType range_begin, range_end;
+          int new_percent = -1;
+          {
+            std::lock_guard<std::mutex> lock(mx_cur_val);
+            if (current_value == x_stop)
+              return;
+
+            range_begin = current_value;
+            if (x_stop >= Increment && x_stop - Increment >= current_value) {
+              range_end = current_value + Increment;
+            } else {
+              range_end = x_stop;
+            }
+            current_value = range_end;
+            int pc = 100.0 * (range_end - x_start) / (x_stop - x_start);
+            if (current_percent != pc) {
+              new_percent = pc;
+              current_percent = pc;
+            }
+          }
+          if (new_percent >= 0) {
+            std::stringstream msg;
+            msg << new_percent << "% is in process     \r";
+            std::cout << msg.str() << std::flush;
+          }
+
+          uint64_t failed_in_range =
+              Checker::check(range_begin, range_end, y_start, y_stop, rounding);
+          if (failed_in_range > 0) {
+            using T = LIBC_NAMESPACE::cpp::conditional_t<
+                LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float,
+                FloatType>;
+            std::stringstream msg;
+            msg << "Test failed for " << std::dec << failed_in_range
+                << " inputs in range: " << range_begin << " to " << range_end
+                << " [0x" << std::hex << range_begin << ", 0x" << range_end
+                << "), [" << std::hexfloat
+                << static_cast<T>(FPBits(range_begin).get_val()) << ", "
+                << static_cast<T>(FPBits(range_end).get_val()) << ")\n";
+            std::cerr << msg.str() << std::flush;
+
+            failed.fetch_add(failed_in_range);
+          }
+        }
+      });
+    }
+
+    for (auto &thread : thread_list) {
+      if (thread.joinable()) {
+        thread.join();
+      }
+    }
+
+    std::cout << std::endl;
+    std::cout << "Test " << ((failed > 0) ? "FAILED" : "PASSED") << std::endl;
+    ASSERT_EQ(failed.load(), uint64_t(0));
+  }
+
+  void test_full_range_all_roundings(StorageType x_start, StorageType x_stop,
+                                     StorageType y_start, StorageType y_stop) {
+    test_full_range(x_start, x_stop, y_start, y_stop,
+                    mpfr::RoundingMode::Nearest);
+
+    std::cout << "-- Testing for FE_UPWARD in x range [0x" << std::hex
+              << x_start << ", 0x" << x_stop << ") y range [0x" << std::hex
+              << y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(x_start, x_stop, y_start, y_stop,
+                    mpfr::RoundingMode::Upward);
+
+    std::cout << "-- Testing for FE_DOWNWARD in x range [0x" << std::hex
+              << x_start << ", 0x" << x_stop << ") y range [0x" << std::hex
+              << y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(x_start, x_stop, y_start, y_stop,
+                    mpfr::RoundingMode::Downward);
+
+    std::cout << "-- Testing for FE_TOWARDZERO in x range [0x" << std::hex
+              << x_start << ", 0x" << x_stop << ") y range [0x" << std::hex
+              << y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(x_start, x_stop, y_start, y_stop,
+                    mpfr::RoundingMode::TowardZero);
+  };
+};
+
 template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func>
 using LlvmLibcUnaryOpExhaustiveMathTest =
     LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, FloatType, Op, Func>>;
@@ -175,3 +321,7 @@ template <typename OutType, typename InType, mpfr::Operation Op,
           UnaryOp<OutType, InType> Func>
 using LlvmLibcUnaryNarrowingOpExhaustiveMathTest =
     LlvmLibcExhaustiveMathTest<UnaryOpChecker<OutType, InType, Op, Func>>;
+
+template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func>
+using LlvmLibcBinaryOpExhaustiveMathTest =
+    LlvmLibcBinaryInputExhaustiveMathTest<BinaryOpChecker<FloatType, Op, Func>>;
diff --git a/libc/test/src/math/exhaustive/fmodf16_test.cpp b/libc/test/src/math/exhaustive/fmodf16_test.cpp
new file mode 100644
index 0000000000000..067cec969a4f7
--- /dev/null
+++ b/libc/test/src/math/exhaustive/fmodf16_test.cpp
@@ -0,0 +1,50 @@
+//===-- Exhaustive test for fmodf16 ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "exhaustive_test.h"
+#include "src/math/fmodf16.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
+
+// Checks LIBC_NAMESPACE::fmodf16 against the MPFR reference for
+// mpfr::Operation::Fmod on every pair of float16 bit patterns in the x and y
+// ranges given to test_full_range_all_roundings (both ends inclusive).
+using LlvmLibcFmodf16ExhaustiveTest =
+    LlvmLibcBinaryOpExhaustiveMathTest<float16, mpfr::Operation::Fmod,
+                                       LIBC_NAMESPACE::fmodf16>;
+
+// Bit patterns +0 (0x0000) through +Inf (0x7c00). Positive NaN encodings
+// (0x7c01-0x7fff) are not covered.
+static constexpr uint16_t POS_START = 0x0000U;
+static constexpr uint16_t POS_STOP = 0x7c00U;
+
+// Bit patterns -0 (0x8000) through -Inf (0xfc00). Negative NaN encodings
+// (0xfc01-0xffff) are not covered.
+static constexpr uint16_t NEG_START = 0x8000U;
+static constexpr uint16_t NEG_STOP = 0xfc00U;
+
+// x in [+0, +Inf], y in [+0, +Inf].
+TEST_F(LlvmLibcFmodf16ExhaustiveTest, PositivePositiveRange) {
+  test_full_range_all_roundings(POS_START, POS_STOP, POS_START, POS_STOP);
+}
+
+// x in [+0, +Inf], y in [-0, -Inf].
+TEST_F(LlvmLibcFmodf16ExhaustiveTest, PositiveNegativeRange) {
+  test_full_range_all_roundings(POS_START, POS_STOP, NEG_START, NEG_STOP);
+}
+
+// x in [-0, -Inf], y in [+0, +Inf].
+TEST_F(LlvmLibcFmodf16ExhaustiveTest, NegativePositiveRange) {
+  test_full_range_all_roundings(NEG_START, NEG_STOP, POS_START, POS_STOP);
+}
+
+// x in [-0, -Inf], y in [-0, -Inf]: both operands come from the negative range.
+TEST_F(LlvmLibcFmodf16ExhaustiveTest, NegativeNegativeRange) {
+  test_full_range_all_roundings(NEG_START, NEG_STOP, NEG_START, NEG_STOP);
+}
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 521c2658b327a..88aef3e6e10c5 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -930,6 +930,8 @@ explain_binary_operation_one_output_error(Operation,
                                           const BinaryInput<long double> &,
                                           long double, double, RoundingMode);
 #ifdef LIBC_TYPES_HAS_FLOAT16
+template void explain_binary_operation_one_output_error(
+    Operation, const BinaryInput<float16> &, float16, double, RoundingMode);
 template void
 explain_binary_operation_one_output_error(Operation, const BinaryInput<float> &,
                                           float16, double, RoundingMode);
@@ -1088,6 +1090,10 @@ template bool
 compare_binary_operation_one_output(Operation, const BinaryInput<long double> &,
                                     long double, double, RoundingMode);
 #ifdef LIBC_TYPES_HAS_FLOAT16
+template bool compare_binary_operation_one_output(Operation,
+                                                  const BinaryInput<float16> &,
+                                                  float16, double,
+                                                  RoundingMode);
 template bool compare_binary_operation_one_output(Operation,
                                                   const BinaryInput<float> &,
                                                   float16, double,

>From bd989f90c7b9b3e23697b4e3bec6014bff5ad207 Mon Sep 17 00:00:00 2001
From: OverMighty <its.overmighty at gmail.com>
Date: Mon, 24 Jun 2024 22:34:29 +0200
Subject: [PATCH 3/3] fixup! [libc][math][c23] Add MPFR exhaustive test for
 fmodf16

---
 .../src/math/exhaustive/exhaustive_test.h     | 186 ++++++------------
 1 file changed, 65 insertions(+), 121 deletions(-)

diff --git a/libc/test/src/math/exhaustive/exhaustive_test.h b/libc/test/src/math/exhaustive/exhaustive_test.h
index 9afbdabb7fd1f..bf9be90f39e2c 100644
--- a/libc/test/src/math/exhaustive/exhaustive_test.h
+++ b/libc/test/src/math/exhaustive/exhaustive_test.h
@@ -69,22 +69,22 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
   }
 };
 
-template <typename T> using BinaryOp = T(T, T);
+template <typename OutType, typename InType = OutType>
+using BinaryOp = OutType(InType, InType);
 
-template <typename T, mpfr::Operation Op, BinaryOp<T> Func>
+template <typename OutType, typename InType, mpfr::Operation Op,
+          BinaryOp<OutType, InType> Func>
 struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
-  using FloatType = T;
+  using FloatType = InType;
   using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
   using StorageType = typename FPBits::StorageType;
 
-  static constexpr BinaryOp<FloatType> *FUNC = Func;
-
   // Check in a range, return the number of failures.
   uint64_t check(StorageType x_start, StorageType x_stop, StorageType y_start,
                  StorageType y_stop, mpfr::RoundingMode rounding) {
     mpfr::ForceRoundingMode r(rounding);
     if (!r.success)
-      return (x_stop > x_start || y_stop > y_start);
+      return x_stop > x_start || y_stop > y_start;
     StorageType xbits = x_start;
     uint64_t failed = 0;
     do {
@@ -93,12 +93,12 @@ struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
       do {
         FloatType y = FPBits(ybits).get_val();
         mpfr::BinaryInput<FloatType> input{x, y};
-        bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, FUNC(x, y),
+        bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, input, Func(x, y),
                                                          0.5, rounding);
         failed += (!correct);
         // Uncomment to print out failed values.
         // if (!correct) {
-        //   TEST_MPFR_MATCH(Op::Operation, x, Op::func(x, y), 0.5, rounding);
+        //   EXPECT_MPFR_MATCH_ROUNDING(Op, input, Func(x, y), 0.5, rounding);
         // }
       } while (ybits++ < y_stop);
     } while (xbits++ < x_stop);
@@ -108,7 +108,7 @@ struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
 
 // Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
 //   StorageType and check method.
-template <typename Checker>
+template <typename Checker, size_t Increment = 1 << 20>
 struct LlvmLibcExhaustiveMathTest
     : public virtual LIBC_NAMESPACE::testing::Test,
       public Checker {
@@ -116,12 +116,37 @@ struct LlvmLibcExhaustiveMathTest
   using FPBits = typename Checker::FPBits;
   using StorageType = typename Checker::StorageType;
 
-  static constexpr StorageType INCREMENT = (1 << 20);
+  static constexpr StorageType INCREMENT = Increment;
+
+  void explain_failed_range(std::stringstream &msg, StorageType x_begin,
+                            StorageType x_end) {
+#ifdef LIBC_TYPES_HAS_FLOAT16
+    using T = LIBC_NAMESPACE::cpp::conditional_t<
+        LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float, FloatType>;
+#else
+    using T = FloatType;
+#endif
+
+    msg << x_begin << " to " << x_end << " [0x" << std::hex << x_begin << ", 0x"
+        << x_end << "), [" << std::hexfloat
+        << static_cast<T>(FPBits(x_begin).get_val()) << ", "
+        << static_cast<T>(FPBits(x_end).get_val()) << ")";
+  }
+
+  void explain_failed_range(std::stringstream &msg, StorageType x_begin,
+                            StorageType x_end, StorageType y_begin,
+                            StorageType y_end) {
+    msg << "x ";
+    explain_failed_range(msg, x_begin, x_end);
+    msg << ", y ";
+    explain_failed_range(msg, y_begin, y_end);
+  }
 
   // Break [start, stop) into `nthreads` subintervals and apply *check to each
   // subinterval in parallel.
-  void test_full_range(StorageType start, StorageType stop,
-                       mpfr::RoundingMode rounding) {
+  template <typename... T>
+  void test_full_range(mpfr::RoundingMode rounding, StorageType start,
+                       StorageType stop, T... extra_range_bounds) {
     int n_threads = std::thread::hardware_concurrency();
     std::vector<std::thread> thread_list;
     std::mutex mx_cur_val;
@@ -158,15 +183,17 @@ struct LlvmLibcExhaustiveMathTest
             std::cout << msg.str() << std::flush;
           }
 
-          uint64_t failed_in_range =
-              Checker::check(range_begin, range_end, rounding);
+          uint64_t failed_in_range = Checker::check(
+              range_begin, range_end, extra_range_bounds..., rounding);
           if (failed_in_range > 0) {
             std::stringstream msg;
             msg << "Test failed for " << std::dec << failed_in_range
-                << " inputs in range: " << range_begin << " to " << range_end
-                << " [0x" << std::hex << range_begin << ", 0x" << range_end
-                << "), [" << std::hexfloat << FPBits(range_begin).get_val()
-                << ", " << FPBits(range_end).get_val() << ")\n";
+                << " inputs in range: ";
+            // Report the x sub-range this task checked; the extra (y) bounds
+            // are passed to the checker unchanged, so report them as-is.
+            explain_failed_range(msg, range_begin, range_end,
+                                 extra_range_bounds...);
+            msg << "\n";
             std::cerr << msg.str() << std::flush;
 
             failed.fetch_add(failed_in_range);
@@ -189,127 +213,46 @@ struct LlvmLibcExhaustiveMathTest
   void test_full_range_all_roundings(StorageType start, StorageType stop) {
     std::cout << "-- Testing for FE_TONEAREST in range [0x" << std::hex << start
               << ", 0x" << stop << ") --" << std::dec << std::endl;
-    test_full_range(start, stop, mpfr::RoundingMode::Nearest);
+    test_full_range(mpfr::RoundingMode::Nearest, start, stop);  // no y bounds
 
     std::cout << "-- Testing for FE_UPWARD in range [0x" << std::hex << start
               << ", 0x" << stop << ") --" << std::dec << std::endl;
-    test_full_range(start, stop, mpfr::RoundingMode::Upward);
+    test_full_range(mpfr::RoundingMode::Upward, start, stop);
 
     std::cout << "-- Testing for FE_DOWNWARD in range [0x" << std::hex << start
               << ", 0x" << stop << ") --" << std::dec << std::endl;
-    test_full_range(start, stop, mpfr::RoundingMode::Downward);
+    test_full_range(mpfr::RoundingMode::Downward, start, stop);
 
     std::cout << "-- Testing for FE_TOWARDZERO in range [0x" << std::hex
               << start << ", 0x" << stop << ") --" << std::dec << std::endl;
-    test_full_range(start, stop, mpfr::RoundingMode::TowardZero);
+    test_full_range(mpfr::RoundingMode::TowardZero, start, stop);
   };
-};
-
-template <typename Checker>
-struct LlvmLibcBinaryInputExhaustiveMathTest
-    : public virtual LIBC_NAMESPACE::testing::Test,
-      public Checker {
-  using FloatType = typename Checker::FloatType;
-  using FPBits = typename Checker::FPBits;
-  using StorageType = typename Checker::StorageType;
-
-  static constexpr StorageType Increment = (1 << 2);
-
-  // Break [start, stop) into `nthreads` subintervals and apply *check to each
-  // subinterval in parallel.
-  void test_full_range(StorageType x_start, StorageType x_stop,
-                       StorageType y_start, StorageType y_stop,
-                       mpfr::RoundingMode rounding) {
-    int n_threads = std::thread::hardware_concurrency();
-    std::vector<std::thread> thread_list;
-    std::mutex mx_cur_val;
-    int current_percent = -1;
-    StorageType current_value = x_start;
-    std::atomic<uint64_t> failed(0);
-
-    for (int i = 0; i < n_threads; ++i) {
-      thread_list.emplace_back([&, this]() {
-        while (true) {
-          StorageType range_begin, range_end;
-          int new_percent = -1;
-          {
-            std::lock_guard<std::mutex> lock(mx_cur_val);
-            if (current_value == x_stop)
-              return;
-
-            range_begin = current_value;
-            if (x_stop >= Increment && x_stop - Increment >= current_value) {
-              range_end = current_value + Increment;
-            } else {
-              range_end = x_stop;
-            }
-            current_value = range_end;
-            int pc = 100.0 * (range_end - x_start) / (x_stop - x_start);
-            if (current_percent != pc) {
-              new_percent = pc;
-              current_percent = pc;
-            }
-          }
-          if (new_percent >= 0) {
-            std::stringstream msg;
-            msg << new_percent << "% is in process     \r";
-            std::cout << msg.str() << std::flush;
-          }
-
-          uint64_t failed_in_range =
-              Checker::check(range_begin, range_end, y_start, y_stop, rounding);
-          if (failed_in_range > 0) {
-            using T = LIBC_NAMESPACE::cpp::conditional_t<
-                LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float,
-                FloatType>;
-            std::stringstream msg;
-            msg << "Test failed for " << std::dec << failed_in_range
-                << " inputs in range: " << range_begin << " to " << range_end
-                << " [0x" << std::hex << range_begin << ", 0x" << range_end
-                << "), [" << std::hexfloat
-                << static_cast<T>(FPBits(range_begin).get_val()) << ", "
-                << static_cast<T>(FPBits(range_end).get_val()) << ")\n";
-            std::cerr << msg.str() << std::flush;
-
-            failed.fetch_add(failed_in_range);
-          }
-        }
-      });
-    }
-
-    for (auto &thread : thread_list) {
-      if (thread.joinable()) {
-        thread.join();
-      }
-    }
-
-    std::cout << std::endl;
-    std::cout << "Test " << ((failed > 0) ? "FAILED" : "PASSED") << std::endl;
-    ASSERT_EQ(failed.load(), uint64_t(0));
-  }
 
   void test_full_range_all_roundings(StorageType x_start, StorageType x_stop,
                                      StorageType y_start, StorageType y_stop) {
-    test_full_range(x_start, x_stop, y_start, y_stop,
-                    mpfr::RoundingMode::Nearest);
+    std::cout << "-- Testing for FE_TONEAREST in x range [0x" << std::hex
+              << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
+              << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(mpfr::RoundingMode::Nearest, x_start, x_stop, y_start,
+                    y_stop);  // y bounds become extra_range_bounds
 
     std::cout << "-- Testing for FE_UPWARD in x range [0x" << std::hex
-              << x_start << ", 0x" << x_stop << ") y range [0x" << std::hex
-              << y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl;
-    test_full_range(x_start, x_stop, y_start, y_stop,
-                    mpfr::RoundingMode::Upward);
+              << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
+              << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(mpfr::RoundingMode::Upward, x_start, x_stop, y_start,
+                    y_stop);
 
     std::cout << "-- Testing for FE_DOWNWARD in x range [0x" << std::hex
-              << x_start << ", 0x" << x_stop << ") y range [0x" << std::hex
-              << y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl;
-    test_full_range(x_start, x_stop, y_start, y_stop,
-                    mpfr::RoundingMode::Downward);
+              << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
+              << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(mpfr::RoundingMode::Downward, x_start, x_stop, y_start,
+                    y_stop);
 
     std::cout << "-- Testing for FE_TOWARDZERO in x range [0x" << std::hex
-              << x_start << ", 0x" << x_stop << ") y range [0x" << std::hex
-              << y_start << ", 0x" << y_stop << ") --" << std::dec << std::endl;
-    test_full_range(x_start, x_stop, y_start, y_stop,
-                    mpfr::RoundingMode::TowardZero);
+              << x_start << ", 0x" << x_stop << "), y range [0x" << y_start
+              << ", 0x" << y_stop << ") --" << std::dec << std::endl;
+    test_full_range(mpfr::RoundingMode::TowardZero, x_start, x_stop, y_start,
+                    y_stop);
   };
 };
 
@@ -324,4 +267,5 @@ using LlvmLibcUnaryNarrowingOpExhaustiveMathTest =
 
 template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func>
 using LlvmLibcBinaryOpExhaustiveMathTest =
-    LlvmLibcBinaryInputExhaustiveMathTest<BinaryOpChecker<FloatType, Op, Func>>;
+    LlvmLibcExhaustiveMathTest<BinaryOpChecker<FloatType, FloatType, Op, Func>,
+                               1 << 2>;  // presumably x Increment; confirm



More information about the libc-commits mailing list