[libc-commits] [libc] [libc][math] Add MPFR tests for fmul (PR #96413)

via libc-commits libc-commits at lists.llvm.org
Sat Jun 22 17:54:59 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-libc

Author: Job Henandez Lara (Jobhdez)

<details>
<summary>Changes</summary>

We are adding these MPFR tests to make sure `fmul` is correctly rounded before we add `fmull` and `dmull`.



---
Full diff: https://github.com/llvm/llvm-project/pull/96413.diff


6 Files Affected:

- (modified) libc/test/src/math/CMakeLists.txt (+12) 
- (added) libc/test/src/math/FMulTest.h (+73) 
- (added) libc/test/src/math/fmul_test.cpp (+13) 
- (modified) libc/test/src/math/smoke/FMulTest.h (+46-1) 
- (modified) libc/utils/MPFRWrapper/MPFRUtils.cpp (+26-14) 
- (modified) libc/utils/MPFRWrapper/MPFRUtils.h (+14-8) 


``````````diff
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index bb364c3f0a175..6e6da9f0cef14 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -1729,6 +1729,19 @@ add_fp_unittest(
     libc.src.__support.FPUtil.fp_bits
 )
 
+add_fp_unittest(
+  fmul_test
+  NEED_MPFR
+  SUITE
+    libc-math-unittests
+  SRCS
+    fmul_test.cpp
+  HDRS
+    FMulTest.h
+  DEPENDS
+    libc.src.math.fmul
+)
+
 add_fp_unittest(
   asinhf_test
   NEED_MPFR
diff --git a/libc/test/src/math/FMulTest.h b/libc/test/src/math/FMulTest.h
new file mode 100644
index 0000000000000..b2fec1f856c5a
--- /dev/null
+++ b/libc/test/src/math/FMulTest.h
@@ -0,0 +1,73 @@
+//===-- Utility class to test fmul[l] ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
+#define LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
+
+#include "test/UnitTest/FEnvSafeTest.h"
+#include "test/UnitTest/FPMatcher.h"
+#include "test/UnitTest/Test.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
+
+// Compares a multiply `OutType func(InType, InType)` against MPFR in every
+// rounding mode the test can force, within 0.5 ulp (correct rounding).
+template <typename OutType, typename InType>
+class FmulMPFRTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
+
+  DECLARE_SPECIAL_CONSTANTS(InType)
+
+public:
+  typedef OutType (*FMulFunc)(InType, InType);
+
+  // Finite operands: exact products, signed zeros, tiny and huge magnitudes,
+  // and operands with low-order bits that a narrower OutType has to round.
+  void testFMulMPFR(FMulFunc func) {
+    const mpfr::BinaryInput<InType> INPUTS[] = {
+        {3.0, 5.0}, {0x1.0p1, 0x1.0p-131}, {0x1.0p2, 0x1.0p-129},
+        {1.0, 1.0}, {-0.0, -0.0}, {-0.0, 0.0}, {0.0, -0.0},
+        {0x1.0p100, 0x1.0p100},
+        {1.0, 1.0 + 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150},
+        {1.0, 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150}};
+    checkAll(func, INPUTS);
+  }
+
+  // Special operands: infinities, zeros and NaNs (aNaN, sNaN), combined with
+  // each other and with finite values.
+  void testSpecialInputsMPFR(FMulFunc func) {
+    const mpfr::BinaryInput<InType> INPUTS[] = {
+        {inf, 0x1.0p-129}, {0x1.0p-129, inf}, {inf, 2.0}, {3.0, inf},
+        {0.0, 0.0}, {neg_inf, aNaN}, {aNaN, neg_inf}, {neg_inf, neg_inf},
+        {0.0, neg_inf}, {neg_inf, 0.0}, {neg_inf, 1.0}, {1.0, neg_inf},
+        {neg_inf, 0x1.0p-129}, {0x1.0p-129, neg_inf}, {0.0, 0x1.0p-129},
+        {inf, 0.0}, {0.0, inf}, {0.0, aNaN}, {2.0, aNaN},
+        {0x1.0p-129, aNaN}, {inf, aNaN}, {aNaN, aNaN}, {0.0, sNaN},
+        {2.0, sNaN}, {0x1.0p-129, sNaN}, {inf, sNaN}, {sNaN, sNaN}};
+    checkAll(func, INPUTS);
+  }
+
+private:
+  // Compares func with MPFR for each pair. Iterates by reference, not index:
+  // the ASSERT macro's own rounding-mode loop could shadow an outer `i`.
+  template <typename InputList>
+  void checkAll(FMulFunc func, const InputList &inputs) {
+    for (const mpfr::BinaryInput<InType> &operands : inputs) {
+      ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, operands,
+                                     func(operands.x, operands.y), 0.5);
+    }
+  }
+};
+
+// Defines the MPFR test cases for `func`, an OutType(InType, InType) multiply.
+#define LIST_FMUL_MPFR_TESTS(OutType, InType, func)                            \
+  using LlvmLibcFmulTest = FmulMPFRTest<OutType, InType>;                      \
+  TEST_F(LlvmLibcFmulTest, MulMpfr) { testFMulMPFR(&func); }                   \
+  TEST_F(LlvmLibcFmulTest, NanInfMpfr) { testSpecialInputsMPFR(&func); }
+
+#endif // LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
diff --git a/libc/test/src/math/fmul_test.cpp b/libc/test/src/math/fmul_test.cpp
new file mode 100644
index 0000000000000..16eaa1a818daf
--- /dev/null
+++ b/libc/test/src/math/fmul_test.cpp
@@ -0,0 +1,13 @@
+//===-- Unittests for fmul ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "FMulTest.h"
+
+#include "src/math/fmul.h"
+
+LIST_FMUL_MPFR_TESTS(float, double, LIBC_NAMESPACE::fmul)
diff --git a/libc/test/src/math/smoke/FMulTest.h b/libc/test/src/math/smoke/FMulTest.h
index 33fb82c8d2da1..40d4b4d930326 100644
--- a/libc/test/src/math/smoke/FMulTest.h
+++ b/libc/test/src/math/smoke/FMulTest.h
@@ -12,6 +12,9 @@
 #include "test/UnitTest/FEnvSafeTest.h"
 #include "test/UnitTest/FPMatcher.h"
 #include "test/UnitTest/Test.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
 
 template <typename T, typename R>
 class FmulTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
@@ -94,11 +97,53 @@ class FmulTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
     EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(inf, sNaN));
     EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(sNaN, sNaN));
   }
+
+  // Cross-checks func against MPFR on finite operands in every rounding mode.
+  // NOTE(review): assumes R is the argument type of FMulFunc (double for fmul).
+  // NOTE(review): smoke tests are normally built without NEED_MPFR; confirm
+  // this target links the MPFR wrapper, or keep these cases in ../FMulTest.h.
+  void testMPFR(FMulFunc func) {
+    const mpfr::BinaryInput<R> INPUTS[] = {
+        {3.0, 5.0}, {0x1.0p1, 0x1.0p-131}, {0x1.0p2, 0x1.0p-129},
+        {1.0, 1.0}, {-0.0, -0.0}, {-0.0, 0.0}, {0.0, -0.0},
+        {0x1.0p100, 0x1.0p100},
+        {1.0, 1.0 + 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150},
+        {1.0, 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150}};
+    checkAllMPFR(func, INPUTS);
+  }
+
+  // Cross-checks func against MPFR on infinities, zeros and NaNs (aNaN,
+  // sNaN), combined with each other and with finite values.
+  void testSpecialInputsMPFR(FMulFunc func) {
+    const mpfr::BinaryInput<R> INPUTS[] = {
+        {inf, 0x1.0p-129}, {0x1.0p-129, inf}, {inf, 2.0}, {3.0, inf},
+        {0.0, 0.0}, {neg_inf, aNaN}, {aNaN, neg_inf}, {neg_inf, neg_inf},
+        {0.0, neg_inf}, {neg_inf, 0.0}, {neg_inf, 1.0}, {1.0, neg_inf},
+        {neg_inf, 0x1.0p-129}, {0x1.0p-129, neg_inf}, {0.0, 0x1.0p-129},
+        {inf, 0.0}, {0.0, inf}, {0.0, aNaN}, {2.0, aNaN},
+        {0x1.0p-129, aNaN}, {inf, aNaN}, {aNaN, aNaN}, {0.0, sNaN},
+        {2.0, sNaN}, {0x1.0p-129, sNaN}, {inf, sNaN}, {sNaN, sNaN}};
+    checkAllMPFR(func, INPUTS);
+  }
+
+  // Compares func with MPFR for every operand pair in every rounding mode.
+  // Each pair is bound to a named reference rather than indexed with a loop
+  // counter: ASSERT_MPFR_MATCH_ALL_ROUNDING runs its own rounding-mode loop,
+  // and its index could shadow an outer `i` used inside its arguments.
+  template <typename InputList>
+  void checkAllMPFR(FMulFunc func, const InputList &inputs) {
+    for (const mpfr::BinaryInput<R> &operands : inputs) {
+      ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, operands,
+                                     func(operands.x, operands.y), 0.5);
+    }
+  }
 };
 
 #define LIST_FMUL_TESTS(T, R, func)                                            \
   using LlvmLibcFmulTest = FmulTest<T, R>;                                     \
   TEST_F(LlvmLibcFmulTest, Mul) { testMul(&func); }                            \
-  TEST_F(LlvmLibcFmulTest, NaNInf) { testSpecialInputs(&func); }
+  TEST_F(LlvmLibcFmulTest, NaNInf) { testSpecialInputs(&func); }               \
+  TEST_F(LlvmLibcFmulTest, MulMpfr) { testMPFR(&func); }                       \
+  TEST_F(LlvmLibcFmulTest, NanInfMpfr) { testSpecialInputsMPFR(&func); }
 
 #endif // LLVM_LIBC_TEST_SRC_MATH_SMOKE_FMULTEST_H
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 2eac4dd8e199d..095f2244c2362 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -467,6 +467,12 @@ class MPFRNumber {
     return result;
   }
 
+  MPFRNumber fmul(const MPFRNumber &b) {
+    MPFRNumber result(*this);
+    mpfr_mul(result.value, value, b.value, mpfr_rounding);
+    return result;
+  }
+
   cpp::string str() const {
     // 200 bytes should be more than sufficient to hold a 100-digit number
     // plus additional bytes for the decimal point, '-' sign etc.
@@ -714,6 +720,8 @@ binary_operation_one_output(Operation op, InputType x, InputType y,
     return inputX.hypot(inputY);
   case Operation::Pow:
     return inputX.pow(inputY);
+  case Operation::Fmul:
+    return inputX.fmul(inputY);
   default:
     __builtin_unreachable();
   }
@@ -885,10 +893,10 @@ template void explain_binary_operation_two_outputs_error<long double>(
     Operation, const BinaryInput<long double> &,
     const BinaryOutput<long double> &, double, RoundingMode);
 
-template <typename T>
+template <typename T, typename R>
 void explain_binary_operation_one_output_error(Operation op,
                                                const BinaryInput<T> &input,
-                                               T libc_result,
+                                               R libc_result,
                                                double ulp_tolerance,
                                                RoundingMode rounding) {
   unsigned int precision = get_precision<T>(ulp_tolerance);
@@ -906,19 +914,21 @@ void explain_binary_operation_one_output_error(Operation op,
 
   tlog << "Libc result: " << mpfrMatchValue.str() << '\n'
        << "MPFR result: " << mpfr_result.str() << '\n';
-  tlog << "Libc floating point result bits: " << str(FPBits<T>(libc_result))
+  tlog << "Libc floating point result bits: " << str(FPBits<R>(libc_result))
        << '\n';
   tlog << "              MPFR rounded bits: "
-       << str(FPBits<T>(mpfr_result.as<T>())) << '\n';
+       << str(FPBits<R>(mpfr_result.as<R>())) << '\n';
   tlog << "ULP error: " << mpfr_result.ulp_as_mpfr_number(libc_result).str()
        << '\n';
 }
 
-template void explain_binary_operation_one_output_error<float>(
+template void explain_binary_operation_one_output_error(
     Operation, const BinaryInput<float> &, float, double, RoundingMode);
-template void explain_binary_operation_one_output_error<double>(
+template void explain_binary_operation_one_output_error(
     Operation, const BinaryInput<double> &, double, double, RoundingMode);
-template void explain_binary_operation_one_output_error<long double>(
+template void explain_binary_operation_one_output_error(
+    Operation, const BinaryInput<double> &, float, double, RoundingMode);
+template void explain_binary_operation_one_output_error(
     Operation, const BinaryInput<long double> &, long double, double,
     RoundingMode);
 
@@ -1051,10 +1061,10 @@ template bool compare_binary_operation_two_outputs<long double>(
     Operation, const BinaryInput<long double> &,
     const BinaryOutput<long double> &, double, RoundingMode);
 
-template <typename T>
+template <typename T, typename R>
 bool compare_binary_operation_one_output(Operation op,
                                          const BinaryInput<T> &input,
-                                         T libc_result, double ulp_tolerance,
+                                         R libc_result, double ulp_tolerance,
                                          RoundingMode rounding) {
   unsigned int precision = get_precision<T>(ulp_tolerance);
@@ -1064,11 +1074,13 @@ bool compare_binary_operation_one_output(Operation op,
   return (ulp <= ulp_tolerance);
 }
 
-template bool compare_binary_operation_one_output<float>(
+template bool compare_binary_operation_one_output(
     Operation, const BinaryInput<float> &, float, double, RoundingMode);
-template bool compare_binary_operation_one_output<double>(
+template bool compare_binary_operation_one_output(
     Operation, const BinaryInput<double> &, double, double, RoundingMode);
-template bool compare_binary_operation_one_output<long double>(
+template bool compare_binary_operation_one_output(
+    Operation, const BinaryInput<double> &, float, double, RoundingMode);
+template bool compare_binary_operation_one_output(
     Operation, const BinaryInput<long double> &, long double, double,
     RoundingMode);
 
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index 0b4f42a72ec81..0789a241e4020 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -41,6 +41,7 @@ enum class Operation : int {
   Exp10,
   Expm1,
   Floor,
+  Fmul, // NOTE(review): two-input op listed among unary ops; consider moving.
   Log,
   Log2,
   Log10,
@@ -159,10 +160,10 @@ bool compare_binary_operation_two_outputs(Operation op,
                                           double ulp_tolerance,
                                           RoundingMode rounding);
 
-template <typename T>
+template <typename T, typename R>
 bool compare_binary_operation_one_output(Operation op,
                                          const BinaryInput<T> &input,
-                                         T libc_output, double ulp_tolerance,
+                                         R libc_output, double ulp_tolerance,
                                          RoundingMode rounding);
 
 template <typename InputType, typename OutputType>
@@ -187,10 +188,10 @@ void explain_binary_operation_two_outputs_error(
     const BinaryOutput<T> &match_value, double ulp_tolerance,
     RoundingMode rounding);
 
-template <typename T>
+template <typename T, typename R>
 void explain_binary_operation_one_output_error(Operation op,
                                                const BinaryInput<T> &input,
-                                               T match_value,
+                                               R match_value,
                                                double ulp_tolerance,
                                                RoundingMode rounding);
 
@@ -235,7 +236,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                                rounding);
   }
 
-  template <typename T> bool match(const BinaryInput<T> &in, T out) {
+  template <typename T, typename R>
+  bool match(const BinaryInput<T> &in, R out) {
     return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
                                                rounding);
   }
@@ -268,7 +270,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                                rounding);
   }
 
-  template <typename T> void explain_error(const BinaryInput<T> &in, T out) {
+  template <typename T, typename R>
+  void explain_error(const BinaryInput<T> &in, R out) {
     explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
                                               rounding);
   }
@@ -293,7 +296,12 @@ constexpr bool is_valid_operation() {
       (op == Operation::Fma && internal::IsTernaryInput<InputType>::VALUE &&
        cpp::is_floating_point_v<
            typename internal::MakeScalarInput<InputType>::type> &&
-       cpp::is_floating_point_v<OutputType>);
+       cpp::is_floating_point_v<OutputType>) ||
+      (op == Operation::Fmul &&
+       (cpp::is_same_v<InputType, BinaryInput<float>> ||
+        cpp::is_same_v<InputType, BinaryInput<double>> ||
+        cpp::is_same_v<InputType, BinaryInput<long double>>) &&
+       cpp::is_floating_point_v<OutputType>);
   if (IS_NARROWING_OP)
     return true;
   return (Operation::BeginUnaryOperationsSingleOutput < op &&
``````````

</details>


https://github.com/llvm/llvm-project/pull/96413


More information about the libc-commits mailing list