[libc-commits] [libc] cce6507 - [libc] Add rounding mode support for MPFR testing macros.

Tue Ly via libc-commits libc-commits at lists.llvm.org
Thu Jan 13 10:29:11 PST 2022


Author: Tue Ly
Date: 2022-01-13T13:28:50-05:00
New Revision: cce650776722ba36d81941e0f2e31a9602c88682

URL: https://github.com/llvm/llvm-project/commit/cce650776722ba36d81941e0f2e31a9602c88682
DIFF: https://github.com/llvm/llvm-project/commit/cce650776722ba36d81941e0f2e31a9602c88682.diff

LOG: [libc] Add rounding mode support for MPFR testing macros.

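Add an extra rounding-mode argument to the EXPECT_MPFR_MATCH and ASSERT_MPFR_MATCH macros. When the ulp tolerance is at most 0.5, the MPFR reference is now computed at the input type's precision in the requested rounding mode, so the result is checked for correct rounding in that mode.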

Reviewed By: sivachandra, michaelrj

Differential Revision: https://reviews.llvm.org/D116777

Added: 
    

Modified: 
    libc/test/src/math/SqrtTest.h
    libc/utils/MPFRWrapper/MPFRUtils.cpp
    libc/utils/MPFRWrapper/MPFRUtils.h

Removed: 
    


################################################################################
diff  --git a/libc/test/src/math/SqrtTest.h b/libc/test/src/math/SqrtTest.h
index f3e4def723aaf..79306a60e2baa 100644
--- a/libc/test/src/math/SqrtTest.h
+++ b/libc/test/src/math/SqrtTest.h
@@ -1,4 +1,4 @@
-//===-- Utility class to test fabs[f|l] -------------------------*- C++ -*-===//
+//===-- Utility class to test sqrt[f|l] -------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -24,7 +24,7 @@ template <typename T> class SqrtTest : public __llvm_libc::testing::Test {
 public:
   typedef T (*SqrtFunc)(T);
 
-  void testSpecialNumbers(SqrtFunc func) {
+  void test_special_numbers(SqrtFunc func) {
     ASSERT_FP_EQ(aNaN, func(aNaN));
     ASSERT_FP_EQ(inf, func(inf));
     ASSERT_FP_EQ(aNaN, func(neg_inf));
@@ -36,24 +36,23 @@ template <typename T> class SqrtTest : public __llvm_libc::testing::Test {
     ASSERT_FP_EQ(T(3.0), func(T(9.0)));
   }
 
-  void testDenormalValues(SqrtFunc func) {
+  void test_denormal_values(SqrtFunc func) {
     for (UIntType mant = 1; mant < HIDDEN_BIT; mant <<= 1) {
       FPBits denormal(T(0.0));
       denormal.set_mantissa(mant);
 
-      ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, T(denormal), func(T(denormal)),
-                        T(0.5));
+      test_all_rounding_modes(func, T(denormal));
     }
 
     constexpr UIntType COUNT = 1'000'001;
     constexpr UIntType STEP = HIDDEN_BIT / COUNT;
     for (UIntType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
       T x = *reinterpret_cast<T *>(&v);
-      ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5);
+      test_all_rounding_modes(func, x);
     }
   }
 
-  void testNormalRange(SqrtFunc func) {
+  void test_normal_range(SqrtFunc func) {
     constexpr UIntType COUNT = 10'000'001;
     constexpr UIntType STEP = UIntType(-1) / COUNT;
     for (UIntType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
@@ -61,13 +60,39 @@ template <typename T> class SqrtTest : public __llvm_libc::testing::Test {
       if (isnan(x) || (x < 0)) {
         continue;
       }
-      ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5);
+      test_all_rounding_modes(func, x);
     }
   }
+
+  // Checks func(x) against MPFR in each of the four rounding modes. Every
+  // ForceRoundingMode guard lives in its own scope, so the previous mode is
+  // restored before the next one is installed.
+  void test_all_rounding_modes(SqrtFunc func, T x) {
+    {
+      mpfr::ForceRoundingMode r(mpfr::RoundingMode::Nearest);
+      EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
+                        mpfr::RoundingMode::Nearest);
+    }
+    {
+      mpfr::ForceRoundingMode r(mpfr::RoundingMode::Upward);
+      EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
+                        mpfr::RoundingMode::Upward);
+    }
+    {
+      mpfr::ForceRoundingMode r(mpfr::RoundingMode::Downward);
+      EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
+                        mpfr::RoundingMode::Downward);
+    }
+    {
+      mpfr::ForceRoundingMode r(mpfr::RoundingMode::TowardZero);
+      EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
+                        mpfr::RoundingMode::TowardZero);
+    }
+  }
 };
 
 #define LIST_SQRT_TESTS(T, func)                                               \
   using LlvmLibcSqrtTest = SqrtTest<T>;                                        \
-  TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { testSpecialNumbers(&func); }      \
-  TEST_F(LlvmLibcSqrtTest, DenormalValues) { testDenormalValues(&func); }      \
-  TEST_F(LlvmLibcSqrtTest, NormalRange) { testNormalRange(&func); }
+  TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); }    \
+  TEST_F(LlvmLibcSqrtTest, DenormalValues) { test_denormal_values(&func); }    \
+  TEST_F(LlvmLibcSqrtTest, NormalRange) { test_normal_range(&func); }

diff  --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 4e25c874d122f..7bb3a7fc5f598 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -14,6 +14,7 @@
 #include "utils/UnitTest/FPMatcher.h"
 
 #include <cmath>
+#include <fenv.h>
 #include <memory>
 #include <stdint.h>
 #include <string>
@@ -55,141 +56,227 @@ template <> struct Precision<long double> {
 };
 #endif
 
+// MPFR working precision (bits) well above T's own precision. It is the
+// default for MPFRNumber's value constructors and for ulp tolerances > 0.5.
+template <typename T> struct ExtraPrecision;
+
+template <> struct ExtraPrecision<float> {
+  static constexpr unsigned int VALUE = 128;
+};
+
+template <> struct ExtraPrecision<double> {
+  static constexpr unsigned int VALUE = 256;
+};
+
+template <> struct ExtraPrecision<long double> {
+  static constexpr unsigned int VALUE = 256;
+};
+
+// MPFR precision for a reference result: a tolerance <= 0.5 demands correct
+// rounding, so T's own precision is used; otherwise ExtraPrecision<T>.
+// NOTE(review): results in T's subnormal range can still be double-rounded.
+template <typename T>
+static inline unsigned int get_precision(double ulp_tolerance) {
+  if (ulp_tolerance <= 0.5) {
+    return Precision<T>::VALUE;
+  } else {
+    return ExtraPrecision<T>::VALUE;
+  }
+}
+
+// Maps a RoundingMode to the MPFR rounding constant with the same meaning.
+static inline mpfr_rnd_t get_mpfr_rounding_mode(RoundingMode mode) {
+  switch (mode) {
+  case RoundingMode::Upward:
+    return MPFR_RNDU;
+  case RoundingMode::Downward:
+    return MPFR_RNDD;
+  case RoundingMode::TowardZero:
+    return MPFR_RNDZ;
+  case RoundingMode::Nearest:
+    return MPFR_RNDN;
+  }
+  // Unreachable for valid enumerators; the fallback avoids flowing off the
+  // end of a non-void function (undefined behavior) on an invalid value.
+  return MPFR_RNDN;
+}
+
+// Maps a RoundingMode to the matching <fenv.h> FE_* rounding macro.
+int get_fe_rounding(RoundingMode mode) {
+  switch (mode) {
+  case RoundingMode::Upward:
+    return FE_UPWARD;
+  case RoundingMode::Downward:
+    return FE_DOWNWARD;
+  case RoundingMode::TowardZero:
+    return FE_TOWARDZERO;
+  case RoundingMode::Nearest:
+    return FE_TONEAREST;
+  }
+  // Unreachable for valid enumerators; the fallback avoids flowing off the
+  // end of a non-void function (undefined behavior) on an invalid value.
+  return FE_TONEAREST;
+}
+
+ForceRoundingMode::ForceRoundingMode(RoundingMode mode) {
+  old_rounding_mode = fegetround(); // Restored by the destructor if changed.
+  rounding_mode = get_fe_rounding(mode);
+  if (old_rounding_mode != rounding_mode)
+    fesetround(rounding_mode); // NOTE(review): return value is ignored.
+}
+
+ForceRoundingMode::~ForceRoundingMode() {
+  if (old_rounding_mode != rounding_mode)
+    fesetround(old_rounding_mode);
+}
+
 class MPFRNumber {
-  // A precision value which allows sufficiently large additional
-  // precision even compared to quad-precision floating point values.
   unsigned int mpfr_precision;
+  mpfr_rnd_t mpfr_rounding; // Rounding mode for MPFR ops on this value.
 
   mpfr_t value;
 
 public:
-  MPFRNumber() : mpfr_precision(256) { mpfr_init2(value, mpfr_precision); }
+  MPFRNumber() : mpfr_precision(256), mpfr_rounding(MPFR_RNDN) {
+    mpfr_init2(value, mpfr_precision);
+  }
 
   // We use explicit EnableIf specializations to disallow implicit
   // conversions. Implicit conversions can potentially lead to loss of
   // precision.
   template <typename XType,
             cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
-  explicit MPFRNumber(XType x, int precision = 128)
-      : mpfr_precision(precision) {
+  explicit MPFRNumber(XType x, int precision = ExtraPrecision<XType>::VALUE,
+                      RoundingMode rounding = RoundingMode::Nearest)
+      : mpfr_precision(precision),
+        mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
     mpfr_init2(value, mpfr_precision);
-    mpfr_set_flt(value, x, MPFR_RNDN);
+    mpfr_set_flt(value, x, mpfr_rounding);
   }
 
   template <typename XType,
             cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
-  explicit MPFRNumber(XType x, int precision = 128)
-      : mpfr_precision(precision) {
+  explicit MPFRNumber(XType x, int precision = ExtraPrecision<XType>::VALUE,
+                      RoundingMode rounding = RoundingMode::Nearest)
+      : mpfr_precision(precision),
+        mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
     mpfr_init2(value, mpfr_precision);
-    mpfr_set_d(value, x, MPFR_RNDN);
+    mpfr_set_d(value, x, mpfr_rounding);
   }
 
   template <typename XType,
             cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
-  explicit MPFRNumber(XType x, int precision = 128)
-      : mpfr_precision(precision) {
+  explicit MPFRNumber(XType x, int precision = ExtraPrecision<XType>::VALUE,
+                      RoundingMode rounding = RoundingMode::Nearest)
+      : mpfr_precision(precision),
+        mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
     mpfr_init2(value, mpfr_precision);
-    mpfr_set_ld(value, x, MPFR_RNDN);
+    mpfr_set_ld(value, x, mpfr_rounding);
   }
 
   template <typename XType,
             cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
-  explicit MPFRNumber(XType x, int precision = 128)
-      : mpfr_precision(precision) {
+  explicit MPFRNumber(XType x, int precision = ExtraPrecision<float>::VALUE,
+                      RoundingMode rounding = RoundingMode::Nearest)
+      : mpfr_precision(precision),
+        mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
     mpfr_init2(value, mpfr_precision);
-    mpfr_set_sj(value, x, MPFR_RNDN);
+    mpfr_set_sj(value, x, mpfr_rounding);
   }
 
-  MPFRNumber(const MPFRNumber &other) : mpfr_precision(other.mpfr_precision) {
+  MPFRNumber(const MPFRNumber &other)
+      : mpfr_precision(other.mpfr_precision),
+        mpfr_rounding(other.mpfr_rounding) {
     mpfr_init2(value, mpfr_precision);
-    mpfr_set(value, other.value, MPFR_RNDN);
+    mpfr_set(value, other.value, mpfr_rounding);
   }
 
   ~MPFRNumber() { mpfr_clear(value); }
 
   MPFRNumber &operator=(const MPFRNumber &rhs) {
     mpfr_precision = rhs.mpfr_precision;
-    mpfr_set(value, rhs.value, MPFR_RNDN);
+    mpfr_rounding = rhs.mpfr_rounding;
+    mpfr_set(value, rhs.value, mpfr_rounding); // Keeps lhs mpfr_t precision.
     return *this;
   }
 
   MPFRNumber abs() const {
-    MPFRNumber result;
-    mpfr_abs(result.value, value, MPFR_RNDN);
+    MPFRNumber result(*this); // Inherit precision and rounding.
+    mpfr_abs(result.value, value, mpfr_rounding);
     return result;
   }
 
   MPFRNumber ceil() const {
-    MPFRNumber result;
+    MPFRNumber result(*this);
     mpfr_ceil(result.value, value);
     return result;
   }
 
   MPFRNumber cos() const {
-    MPFRNumber result;
-    mpfr_cos(result.value, value, MPFR_RNDN);
+    MPFRNumber result(*this);
+    mpfr_cos(result.value, value, mpfr_rounding);
     return result;
   }
 
   MPFRNumber exp() const {
-    MPFRNumber result;
-    mpfr_exp(result.value, value, MPFR_RNDN);
+    MPFRNumber result(*this);
+    mpfr_exp(result.value, value, mpfr_rounding);
     return result;
   }
 
   MPFRNumber exp2() const {
-    MPFRNumber result;
-    mpfr_exp2(result.value, value, MPFR_RNDN);
+    MPFRNumber result(*this);
+    mpfr_exp2(result.value, value, mpfr_rounding);
     return result;
   }
 
   MPFRNumber expm1() const {
-    MPFRNumber result;
-    mpfr_expm1(result.value, value, MPFR_RNDN);
+    MPFRNumber result(*this);
+    mpfr_expm1(result.value, value, mpfr_rounding);
     return result;
   }
 
   MPFRNumber floor() const {
-    MPFRNumber result;
+    MPFRNumber result(*this);
     mpfr_floor(result.value, value);
     return result;
   }
 
   MPFRNumber frexp(int &exp) {
-    MPFRNumber result;
+    MPFRNumber result(*this);
     mpfr_exp_t resultExp;
-    mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN);
+    mpfr_frexp(&resultExp, result.value, value, mpfr_rounding);
     exp = resultExp;
     return result;
   }
 
   MPFRNumber hypot(const MPFRNumber &b) {
-    MPFRNumber result;
-    mpfr_hypot(result.value, value, b.value, MPFR_RNDN);
+    MPFRNumber result(*this);
+    mpfr_hypot(result.value, value, b.value, mpfr_rounding);
     return result;
   }
 
   MPFRNumber log() const {
-    MPFRNumber result;
-    mpfr_log(result.value, value, MPFR_RNDN);
+    MPFRNumber result(*this);
+    mpfr_log(result.value, value, mpfr_rounding);
     return result;
   }
 
   MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) {
-    MPFRNumber remainder;
+    MPFRNumber remainder(*this);
     long q;
-    mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN);
+    mpfr_remquo(remainder.value, &q, value, divisor.value, mpfr_rounding);
     quotient = q;
     return remainder;
   }
 
   MPFRNumber round() const {
-    MPFRNumber result;
+    MPFRNumber result(*this);
     mpfr_round(result.value, value);
     return result;
   }
 
-  bool roung_to_long(long &result) const {
+  bool round_to_long(long &result) const {
     // We first calculate the rounded value. This way, when converting
     // to long using mpfr_get_si, the rounding direction of MPFR_RNDN
     // (or any other rounding mode), does not have an influence.
@@ -199,14 +286,14 @@ class MPFRNumber {
     return mpfr_erangeflag_p();
   }
 
-  bool roung_to_long(mpfr_rnd_t rnd, long &result) const {
-    MPFRNumber rint_result;
+  bool round_to_long(mpfr_rnd_t rnd, long &result) const {
+    MPFRNumber rint_result(*this);
     mpfr_rint(rint_result.value, value, rnd);
-    return rint_result.roung_to_long(result);
+    return rint_result.round_to_long(result);
   }
 
   MPFRNumber rint(mpfr_rnd_t rnd) const {
-    MPFRNumber result;
+    MPFRNumber result(*this);
     mpfr_rint(result.value, value, rnd);
     return result;
   }
@@ -239,32 +326,32 @@ class MPFRNumber {
   }
 
   MPFRNumber sin() const {
-    MPFRNumber result;
-    mpfr_sin(result.value, value, MPFR_RNDN);
+    MPFRNumber result(*this);
+    mpfr_sin(result.value, value, mpfr_rounding);
     return result;
   }
 
   MPFRNumber sqrt() const {
-    MPFRNumber result;
-    mpfr_sqrt(result.value, value, MPFR_RNDN);
+    MPFRNumber result(*this);
+    mpfr_sqrt(result.value, value, mpfr_rounding);
     return result;
   }
 
   MPFRNumber tan() const {
-    MPFRNumber result;
-    mpfr_tan(result.value, value, MPFR_RNDN);
+    MPFRNumber result(*this);
+    mpfr_tan(result.value, value, mpfr_rounding);
     return result;
   }
 
   MPFRNumber trunc() const {
-    MPFRNumber result;
+    MPFRNumber result(*this);
     mpfr_trunc(result.value, value);
     return result;
   }
 
   MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) {
     MPFRNumber result(*this);
-    mpfr_fma(result.value, value, b.value, c.value, MPFR_RNDN);
+    mpfr_fma(result.value, value, b.value, c.value, mpfr_rounding);
     return result;
   }
 
@@ -282,10 +369,14 @@ class MPFRNumber {
   // These functions are useful for debugging.
   template <typename T> T as() const;
 
-  template <> float as<float>() const { return mpfr_get_flt(value, MPFR_RNDN); }
-  template <> double as<double>() const { return mpfr_get_d(value, MPFR_RNDN); }
+  template <> float as<float>() const {
+    return mpfr_get_flt(value, mpfr_rounding);
+  }
+  template <> double as<double>() const {
+    return mpfr_get_d(value, mpfr_rounding);
+  }
   template <> long double as<long double>() const {
-    return mpfr_get_ld(value, MPFR_RNDN);
+    return mpfr_get_ld(value, mpfr_rounding);
   }
 
   void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
@@ -378,8 +469,9 @@ namespace internal {
 
 template <typename InputType>
 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
-unary_operation(Operation op, InputType input) {
-  MPFRNumber mpfrInput(input);
+unary_operation(Operation op, InputType input, unsigned int precision,
+                RoundingMode rounding) {
+  MPFRNumber mpfrInput(input, precision, rounding);
   switch (op) {
   case Operation::Abs:
     return mpfrInput.abs();
@@ -420,8 +512,9 @@ unary_operation(Operation op, InputType input) {
 
 template <typename InputType>
 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
-unary_operation_two_outputs(Operation op, InputType input, int &output) {
-  MPFRNumber mpfrInput(input);
+unary_operation_two_outputs(Operation op, InputType input, int &output,
+                            unsigned int precision, RoundingMode rounding) {
+  MPFRNumber mpfrInput(input, precision, rounding);
   switch (op) {
   case Operation::Frexp:
     return mpfrInput.frexp(output);
@@ -432,8 +525,10 @@ unary_operation_two_outputs(Operation op, InputType input, int &output) {
 
 template <typename InputType>
 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
-binary_operation_one_output(Operation op, InputType x, InputType y) {
-  MPFRNumber inputX(x), inputY(y);
+binary_operation_one_output(Operation op, InputType x, InputType y,
+                            unsigned int precision, RoundingMode rounding) {
+  MPFRNumber inputX(x, precision, rounding);
+  MPFRNumber inputY(y, precision, rounding);
   switch (op) {
   case Operation::Hypot:
     return inputX.hypot(inputY);
@@ -445,8 +540,10 @@ binary_operation_one_output(Operation op, InputType x, InputType y) {
 template <typename InputType>
 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
 binary_operation_two_outputs(Operation op, InputType x, InputType y,
-                             int &output) {
-  MPFRNumber inputX(x), inputY(y);
+                             int &output, unsigned int precision,
+                             RoundingMode rounding) {
+  MPFRNumber inputX(x, precision, rounding);
+  MPFRNumber inputY(y, precision, rounding);
   switch (op) {
   case Operation::RemQuo:
     return inputX.remquo(inputY, output);
@@ -458,12 +555,14 @@ binary_operation_two_outputs(Operation op, InputType x, InputType y,
 template <typename InputType>
 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
 ternary_operation_one_output(Operation op, InputType x, InputType y,
-                             InputType z) {
+                             InputType z, unsigned int precision,
+                             RoundingMode rounding) {
-  // For FMA function, we just need to compare with the mpfr_fma with the same
-  // precision as InputType.  Using higher precision as the intermediate results
-  // to compare might incorrectly fail due to double-rounding errors.
-  constexpr unsigned int prec = Precision<InputType>::VALUE;
-  MPFRNumber inputX(x, prec), inputY(y, prec), inputZ(z, prec);
+  // Callers derive `precision` from get_precision(). For an ulp tolerance
+  // <= 0.5 that is InputType's own precision, so mpfr_fma rounds only once;
+  // a wider working precision could fail correct results via double rounding.
+  MPFRNumber inputX(x, precision, rounding);
+  MPFRNumber inputY(y, precision, rounding);
+  MPFRNumber inputZ(z, precision, rounding);
   switch (op) {
   case Operation::Fma:
     return inputX.fma(inputY, inputZ);
@@ -475,13 +574,14 @@ ternary_operation_one_output(Operation op, InputType x, InputType y,
 template <typename T>
 void explain_unary_operation_single_output_error(Operation op, T input,
                                                  T matchValue,
+                                                 double ulp_tolerance,
+                                                 RoundingMode rounding,
                                                  testutils::StreamWrapper &OS) {
-  MPFRNumber mpfrInput(input);
-  MPFRNumber mpfr_result = unary_operation(op, input);
+  unsigned int precision = get_precision<T>(ulp_tolerance);
+  MPFRNumber mpfrInput(input, precision);
+  // Direct-initialize: operator= would keep the default 256-bit mpfr_t.
+  MPFRNumber mpfr_result = unary_operation(op, input, precision, rounding);
   MPFRNumber mpfrMatchValue(matchValue);
-  FPBits<T> inputBits(input);
-  FPBits<T> matchBits(matchValue);
-  FPBits<T> mpfr_resultBits(mpfr_result.as<T>());
   OS << "Match value not within tolerance value of MPFR result:\n"
      << "  Input decimal: " << mpfrInput.str() << '\n';
   __llvm_libc::fputil::testing::describeValue("     Input bits: ", input, OS);
@@ -498,21 +598,24 @@ void explain_unary_operation_single_output_error(Operation op, T input,
 
 template void
 explain_unary_operation_single_output_error<float>(Operation op, float, float,
+                                                   double, RoundingMode,
                                                    testutils::StreamWrapper &);
 template void explain_unary_operation_single_output_error<double>(
-    Operation op, double, double, testutils::StreamWrapper &);
+    Operation op, double, double, double, RoundingMode,
+    testutils::StreamWrapper &);
 template void explain_unary_operation_single_output_error<long double>(
-    Operation op, long double, long double, testutils::StreamWrapper &);
+    Operation op, long double, long double, double, RoundingMode,
+    testutils::StreamWrapper &);
 
 template <typename T>
 void explain_unary_operation_two_outputs_error(
     Operation op, T input, const BinaryOutput<T> &libc_result,
-    testutils::StreamWrapper &OS) {
-  MPFRNumber mpfrInput(input);
-  FPBits<T> inputBits(input);
+    double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) {
+  unsigned int precision = get_precision<T>(ulp_tolerance);
+  MPFRNumber mpfrInput(input, precision);
   int mpfrIntResult;
-  MPFRNumber mpfr_result =
-      unary_operation_two_outputs(op, input, mpfrIntResult);
+  MPFRNumber mpfr_result = unary_operation_two_outputs(op, input, mpfrIntResult,
+                                                       precision, rounding);
 
   if (mpfrIntResult != libc_result.i) {
     OS << "MPFR integral result: " << mpfrIntResult << '\n'
@@ -541,26 +644,26 @@ void explain_unary_operation_two_outputs_error(
 }
 
 template void explain_unary_operation_two_outputs_error<float>(
-    Operation, float, const BinaryOutput<float> &, testutils::StreamWrapper &);
-template void
-explain_unary_operation_two_outputs_error<double>(Operation, double,
-                                                  const BinaryOutput<double> &,
-                                                  testutils::StreamWrapper &);
-template void explain_unary_operation_two_outputs_error<long double>(
-    Operation, long double, const BinaryOutput<long double> &,
+    Operation, float, const BinaryOutput<float> &, double, RoundingMode,
     testutils::StreamWrapper &);
+template void explain_unary_operation_two_outputs_error<double>(
+    Operation, double, const BinaryOutput<double> &, double, RoundingMode,
+    testutils::StreamWrapper &);
+template void explain_unary_operation_two_outputs_error<long double>(
+    Operation, long double, const BinaryOutput<long double> &, double,
+    RoundingMode, testutils::StreamWrapper &);
 
 template <typename T>
 void explain_binary_operation_two_outputs_error(
     Operation op, const BinaryInput<T> &input,
-    const BinaryOutput<T> &libc_result, testutils::StreamWrapper &OS) {
-  MPFRNumber mpfrX(input.x);
-  MPFRNumber mpfrY(input.y);
-  FPBits<T> xbits(input.x);
-  FPBits<T> ybits(input.y);
+    const BinaryOutput<T> &libc_result, double ulp_tolerance,
+    RoundingMode rounding, testutils::StreamWrapper &OS) {
+  unsigned int precision = get_precision<T>(ulp_tolerance);
+  MPFRNumber mpfrX(input.x, precision);
+  MPFRNumber mpfrY(input.y, precision);
   int mpfrIntResult;
-  MPFRNumber mpfr_result =
-      binary_operation_two_outputs(op, input.x, input.y, mpfrIntResult);
+  MPFRNumber mpfr_result = binary_operation_two_outputs(
+      op, input.x, input.y, mpfrIntResult, precision, rounding);
   MPFRNumber mpfrMatchValue(libc_result.f);
 
   OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n'
@@ -576,25 +679,27 @@ void explain_binary_operation_two_outputs_error(
 }
 
 template void explain_binary_operation_two_outputs_error<float>(
-    Operation, const BinaryInput<float> &, const BinaryOutput<float> &,
-    testutils::StreamWrapper &);
+    Operation, const BinaryInput<float> &, const BinaryOutput<float> &, double,
+    RoundingMode, testutils::StreamWrapper &);
 template void explain_binary_operation_two_outputs_error<double>(
     Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
-    testutils::StreamWrapper &);
+    double, RoundingMode, testutils::StreamWrapper &);
 template void explain_binary_operation_two_outputs_error<long double>(
     Operation, const BinaryInput<long double> &,
-    const BinaryOutput<long double> &, testutils::StreamWrapper &);
+    const BinaryOutput<long double> &, double, RoundingMode,
+    testutils::StreamWrapper &);
 
 template <typename T>
-void explain_binary_operation_one_output_error(Operation op,
-                                               const BinaryInput<T> &input,
-                                               T libc_result,
-                                               testutils::StreamWrapper &OS) {
-  MPFRNumber mpfrX(input.x);
-  MPFRNumber mpfrY(input.y);
+void explain_binary_operation_one_output_error(
+    Operation op, const BinaryInput<T> &input, T libc_result,
+    double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) {
+  unsigned int precision = get_precision<T>(ulp_tolerance);
+  MPFRNumber mpfrX(input.x, precision);
+  MPFRNumber mpfrY(input.y, precision);
   FPBits<T> xbits(input.x);
   FPBits<T> ybits(input.y);
-  MPFRNumber mpfr_result = binary_operation_one_output(op, input.x, input.y);
+  MPFRNumber mpfr_result =
+      binary_operation_one_output(op, input.x, input.y, precision, rounding);
   MPFRNumber mpfrMatchValue(libc_result);
 
   OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n';
@@ -613,26 +718,28 @@ void explain_binary_operation_one_output_error(Operation op,
 }
 
 template void explain_binary_operation_one_output_error<float>(
-    Operation, const BinaryInput<float> &, float, testutils::StreamWrapper &);
+    Operation, const BinaryInput<float> &, float, double, RoundingMode,
+    testutils::StreamWrapper &);
 template void explain_binary_operation_one_output_error<double>(
-    Operation, const BinaryInput<double> &, double, testutils::StreamWrapper &);
-template void explain_binary_operation_one_output_error<long double>(
-    Operation, const BinaryInput<long double> &, long double,
+    Operation, const BinaryInput<double> &, double, double, RoundingMode,
     testutils::StreamWrapper &);
+template void explain_binary_operation_one_output_error<long double>(
+    Operation, const BinaryInput<long double> &, long double, double,
+    RoundingMode, testutils::StreamWrapper &);
 
 template <typename T>
-void explain_ternary_operation_one_output_error(Operation op,
-                                                const TernaryInput<T> &input,
-                                                T libc_result,
-                                                testutils::StreamWrapper &OS) {
-  MPFRNumber mpfrX(input.x, Precision<T>::VALUE);
-  MPFRNumber mpfrY(input.y, Precision<T>::VALUE);
-  MPFRNumber mpfrZ(input.z, Precision<T>::VALUE);
+void explain_ternary_operation_one_output_error(
+    Operation op, const TernaryInput<T> &input, T libc_result,
+    double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) {
+  unsigned int precision = get_precision<T>(ulp_tolerance);
+  MPFRNumber mpfrX(input.x, precision);
+  MPFRNumber mpfrY(input.y, precision);
+  MPFRNumber mpfrZ(input.z, precision);
   FPBits<T> xbits(input.x);
   FPBits<T> ybits(input.y);
   FPBits<T> zbits(input.z);
-  MPFRNumber mpfr_result =
-      ternary_operation_one_output(op, input.x, input.y, input.z);
+  MPFRNumber mpfr_result = ternary_operation_one_output(
+      op, input.x, input.y, input.z, precision, rounding);
   MPFRNumber mpfrMatchValue(libc_result);
 
   OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str()
@@ -654,68 +761,70 @@ void explain_ternary_operation_one_output_error(Operation op,
 }
 
 template void explain_ternary_operation_one_output_error<float>(
-    Operation, const TernaryInput<float> &, float, testutils::StreamWrapper &);
+    Operation, const TernaryInput<float> &, float, double, RoundingMode,
+    testutils::StreamWrapper &);
 template void explain_ternary_operation_one_output_error<double>(
-    Operation, const TernaryInput<double> &, double,
+    Operation, const TernaryInput<double> &, double, double, RoundingMode,
     testutils::StreamWrapper &);
 template void explain_ternary_operation_one_output_error<long double>(
-    Operation, const TernaryInput<long double> &, long double,
-    testutils::StreamWrapper &);
+    Operation, const TernaryInput<long double> &, long double, double,
+    RoundingMode, testutils::StreamWrapper &);
 
 template <typename T>
 bool compare_unary_operation_single_output(Operation op, T input, T libc_result,
-                                           double ulp_error) {
-  // If the ulp error is exactly 0.5 (i.e a tie), we would check that the result
-  // is rounded to the nearest even.
-  MPFRNumber mpfr_result = unary_operation(op, input);
+                                           double ulp_tolerance,
+                                           RoundingMode rounding) {
+  unsigned int precision = get_precision<T>(ulp_tolerance);
+  // Direct-initialize: operator= would keep the default 256-bit mpfr_t.
+  MPFRNumber mpfr_result = unary_operation(op, input, precision, rounding);
   double ulp = mpfr_result.ulp(libc_result);
-  bool bits_are_even = ((FPBits<T>(libc_result).uintval() & 1) == 0);
-  return (ulp < ulp_error) ||
-         ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
+  return (ulp <= ulp_tolerance);
 }
 
 template bool compare_unary_operation_single_output<float>(Operation, float,
-                                                           float, double);
+                                                           float, double,
+                                                           RoundingMode);
 template bool compare_unary_operation_single_output<double>(Operation, double,
-                                                            double, double);
-template bool compare_unary_operation_single_output<long double>(Operation,
-                                                                 long double,
-                                                                 long double,
-                                                                 double);
+                                                            double, double,
+                                                            RoundingMode);
+template bool compare_unary_operation_single_output<long double>(
+    Operation, long double, long double, double, RoundingMode);
 
 template <typename T>
 bool compare_unary_operation_two_outputs(Operation op, T input,
                                          const BinaryOutput<T> &libc_result,
-                                         double ulp_error) {
+                                         double ulp_tolerance,
+                                         RoundingMode rounding) {
   int mpfrIntResult;
-  MPFRNumber mpfr_result =
-      unary_operation_two_outputs(op, input, mpfrIntResult);
+  unsigned int precision = get_precision<T>(ulp_tolerance);
+  MPFRNumber mpfr_result = unary_operation_two_outputs(op, input, mpfrIntResult,
+                                                       precision, rounding);
   double ulp = mpfr_result.ulp(libc_result.f);
 
   if (mpfrIntResult != libc_result.i)
     return false;
 
-  bool bits_are_even = ((FPBits<T>(libc_result.f).uintval() & 1) == 0);
-  return (ulp < ulp_error) ||
-         ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
+  return (ulp <= ulp_tolerance);
 }
 
-template bool
-compare_unary_operation_two_outputs<float>(Operation, float,
-                                           const BinaryOutput<float> &, double);
+template bool compare_unary_operation_two_outputs<float>(
+    Operation, float, const BinaryOutput<float> &, double, RoundingMode);
 template bool compare_unary_operation_two_outputs<double>(
-    Operation, double, const BinaryOutput<double> &, double);
+    Operation, double, const BinaryOutput<double> &, double, RoundingMode);
 template bool compare_unary_operation_two_outputs<long double>(
-    Operation, long double, const BinaryOutput<long double> &, double);
+    Operation, long double, const BinaryOutput<long double> &, double,
+    RoundingMode);
 
 template <typename T>
 bool compare_binary_operation_two_outputs(Operation op,
                                           const BinaryInput<T> &input,
                                           const BinaryOutput<T> &libc_result,
-                                          double ulp_error) {
+                                          double ulp_tolerance,
+                                          RoundingMode rounding) {
   int mpfrIntResult;
-  MPFRNumber mpfr_result =
-      binary_operation_two_outputs(op, input.x, input.y, mpfrIntResult);
+  unsigned int precision = get_precision<T>(ulp_tolerance);
+  MPFRNumber mpfr_result = binary_operation_two_outputs(
+      op, input.x, input.y, mpfrIntResult, precision, rounding);
   double ulp = mpfr_result.ulp(libc_result.f);
 
   if (mpfrIntResult != libc_result.i) {
@@ -727,81 +836,66 @@ bool compare_binary_operation_two_outputs(Operation op,
     }
   }
 
-  bool bits_are_even = ((FPBits<T>(libc_result.f).uintval() & 1) == 0);
-  return (ulp < ulp_error) ||
-         ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
+  return (ulp <= ulp_tolerance);
 }
 
 template bool compare_binary_operation_two_outputs<float>(
-    Operation, const BinaryInput<float> &, const BinaryOutput<float> &, double);
+    Operation, const BinaryInput<float> &, const BinaryOutput<float> &, double,
+    RoundingMode);
 template bool compare_binary_operation_two_outputs<double>(
     Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
-    double);
+    double, RoundingMode);
 template bool compare_binary_operation_two_outputs<long double>(
     Operation, const BinaryInput<long double> &,
-    const BinaryOutput<long double> &, double);
+    const BinaryOutput<long double> &, double, RoundingMode);
 
 template <typename T>
 bool compare_binary_operation_one_output(Operation op,
                                          const BinaryInput<T> &input,
-                                         T libc_result, double ulp_error) {
-  MPFRNumber mpfr_result = binary_operation_one_output(op, input.x, input.y);
+                                         T libc_result, double ulp_tolerance,
+                                         RoundingMode rounding) {
+  unsigned int precision = get_precision<T>(ulp_tolerance);
+  MPFRNumber mpfr_result =
+      binary_operation_one_output(op, input.x, input.y, precision, rounding);
   double ulp = mpfr_result.ulp(libc_result);
 
-  bool bits_are_even = ((FPBits<T>(libc_result).uintval() & 1) == 0);
-  return (ulp < ulp_error) ||
-         ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
+  return (ulp <= ulp_tolerance);
 }
 
 template bool compare_binary_operation_one_output<float>(
-    Operation, const BinaryInput<float> &, float, double);
+    Operation, const BinaryInput<float> &, float, double, RoundingMode);
 template bool compare_binary_operation_one_output<double>(
-    Operation, const BinaryInput<double> &, double, double);
+    Operation, const BinaryInput<double> &, double, double, RoundingMode);
 template bool compare_binary_operation_one_output<long double>(
-    Operation, const BinaryInput<long double> &, long double, double);
+    Operation, const BinaryInput<long double> &, long double, double,
+    RoundingMode);
 
 template <typename T>
 bool compare_ternary_operation_one_output(Operation op,
                                           const TernaryInput<T> &input,
-                                          T libc_result, double ulp_error) {
-  MPFRNumber mpfr_result =
-      ternary_operation_one_output(op, input.x, input.y, input.z);
+                                          T libc_result, double ulp_tolerance,
+                                          RoundingMode rounding) {
+  unsigned int precision = get_precision<T>(ulp_tolerance);
+  MPFRNumber mpfr_result = ternary_operation_one_output(
+      op, input.x, input.y, input.z, precision, rounding);
   double ulp = mpfr_result.ulp(libc_result);
 
-  bool bits_are_even = ((FPBits<T>(libc_result).uintval() & 1) == 0);
-  return (ulp < ulp_error) ||
-         ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
+  return (ulp <= ulp_tolerance);
 }
 
 template bool compare_ternary_operation_one_output<float>(
-    Operation, const TernaryInput<float> &, float, double);
+    Operation, const TernaryInput<float> &, float, double, RoundingMode);
 template bool compare_ternary_operation_one_output<double>(
-    Operation, const TernaryInput<double> &, double, double);
+    Operation, const TernaryInput<double> &, double, double, RoundingMode);
 template bool compare_ternary_operation_one_output<long double>(
-    Operation, const TernaryInput<long double> &, long double, double);
-
-static mpfr_rnd_t get_mpfr_rounding_mode(RoundingMode mode) {
-  switch (mode) {
-  case RoundingMode::Upward:
-    return MPFR_RNDU;
-    break;
-  case RoundingMode::Downward:
-    return MPFR_RNDD;
-    break;
-  case RoundingMode::TowardZero:
-    return MPFR_RNDZ;
-    break;
-  case RoundingMode::Nearest:
-    return MPFR_RNDN;
-    break;
-  }
-}
+    Operation, const TernaryInput<long double> &, long double, double,
+    RoundingMode);
 
 } // namespace internal
 
 template <typename T> bool round_to_long(T x, long &result) {
   MPFRNumber mpfr(x);
-  return mpfr.roung_to_long(result);
+  return mpfr.round_to_long(result);
 }
 
 template bool round_to_long<float>(float, long &);
@@ -810,7 +904,7 @@ template bool round_to_long<long double>(long double, long &);
 
 template <typename T> bool round_to_long(T x, RoundingMode mode, long &result) {
   MPFRNumber mpfr(x);
-  return mpfr.roung_to_long(internal::get_mpfr_rounding_mode(mode), result);
+  return mpfr.round_to_long(get_mpfr_rounding_mode(mode), result);
 }
 
 template bool round_to_long<float>(float, RoundingMode, long &);
@@ -819,7 +913,7 @@ template bool round_to_long<long double>(long double, RoundingMode, long &);
 
 template <typename T> T round(T x, RoundingMode mode) {
   MPFRNumber mpfr(x);
-  MPFRNumber result = mpfr.rint(internal::get_mpfr_rounding_mode(mode));
+  MPFRNumber result = mpfr.rint(get_mpfr_rounding_mode(mode));
   return result.as<T>();
 }
 

diff  --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index c52f2a76d9c6e..07c24149a5e3a 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -71,6 +71,18 @@ enum class Operation : int {
   EndTernaryOperationsSingleOutput,
 };
 
+enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest };
+
+int get_fe_rounding(RoundingMode mode);
+
+struct ForceRoundingMode {
+  ForceRoundingMode(RoundingMode);
+  ~ForceRoundingMode();
+
+  int old_rounding_mode;
+  int rounding_mode;
+};
+
 template <typename T> struct BinaryInput {
   static_assert(
       __llvm_libc::cpp::IsFloatingPointType<T>::Value,
@@ -108,65 +120,72 @@ struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
 
 template <typename T>
 bool compare_unary_operation_single_output(Operation op, T input, T libc_output,
-                                           double t);
+                                           double ulp_tolerance,
+                                           RoundingMode rounding);
 template <typename T>
 bool compare_unary_operation_two_outputs(Operation op, T input,
                                          const BinaryOutput<T> &libc_output,
-                                         double t);
+                                         double ulp_tolerance,
+                                         RoundingMode rounding);
 template <typename T>
 bool compare_binary_operation_two_outputs(Operation op,
                                           const BinaryInput<T> &input,
                                           const BinaryOutput<T> &libc_output,
-                                          double t);
+                                          double ulp_tolerance,
+                                          RoundingMode rounding);
 
 template <typename T>
 bool compare_binary_operation_one_output(Operation op,
                                          const BinaryInput<T> &input,
-                                         T libc_output, double t);
+                                         T libc_output, double ulp_tolerance,
+                                         RoundingMode rounding);
 
 template <typename T>
 bool compare_ternary_operation_one_output(Operation op,
                                           const TernaryInput<T> &input,
-                                          T libc_output, double t);
+                                          T libc_output, double ulp_tolerance,
+                                          RoundingMode rounding);
 
 template <typename T>
 void explain_unary_operation_single_output_error(Operation op, T input,
                                                  T match_value,
+                                                 double ulp_tolerance,
+                                                 RoundingMode rounding,
                                                  testutils::StreamWrapper &OS);
 template <typename T>
 void explain_unary_operation_two_outputs_error(
     Operation op, T input, const BinaryOutput<T> &match_value,
-    testutils::StreamWrapper &OS);
+    double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
 template <typename T>
 void explain_binary_operation_two_outputs_error(
     Operation op, const BinaryInput<T> &input,
-    const BinaryOutput<T> &match_value, testutils::StreamWrapper &OS);
+    const BinaryOutput<T> &match_value, double ulp_tolerance,
+    RoundingMode rounding, testutils::StreamWrapper &OS);
 
 template <typename T>
-void explain_binary_operation_one_output_error(Operation op,
-                                               const BinaryInput<T> &input,
-                                               T match_value,
-                                               testutils::StreamWrapper &OS);
+void explain_binary_operation_one_output_error(
+    Operation op, const BinaryInput<T> &input, T match_value,
+    double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
 
 template <typename T>
-void explain_ternary_operation_one_output_error(Operation op,
-                                                const TernaryInput<T> &input,
-                                                T match_value,
-                                                testutils::StreamWrapper &OS);
+void explain_ternary_operation_one_output_error(
+    Operation op, const TernaryInput<T> &input, T match_value,
+    double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
 
 template <Operation op, typename InputType, typename OutputType>
 class MPFRMatcher : public testing::Matcher<OutputType> {
   InputType input;
   OutputType match_value;
   double ulp_tolerance;
+  RoundingMode rounding;
 
 public:
-  MPFRMatcher(InputType testInput, double ulp_tolerance)
-      : input(testInput), ulp_tolerance(ulp_tolerance) {}
+  MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
+      : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}
 
   bool match(OutputType libcResult) {
     match_value = libcResult;
-    return match(input, match_value, ulp_tolerance);
+    return match(input, match_value);
   }
 
   // This method is marked with NOLINT because it the name `explainError`
@@ -176,59 +195,64 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
   }
 
 private:
-  template <typename T> static bool match(T in, T out, double tolerance) {
-    return compare_unary_operation_single_output(op, in, out, tolerance);
+  template <typename T> bool match(T in, T out) {
+    return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
+                                                 rounding);
   }
 
-  template <typename T>
-  static bool match(T in, const BinaryOutput<T> &out, double tolerance) {
-    return compare_unary_operation_two_outputs(op, in, out, tolerance);
+  template <typename T> bool match(T in, const BinaryOutput<T> &out) {
+    return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
+                                               rounding);
   }
 
-  template <typename T>
-  static bool match(const BinaryInput<T> &in, T out, double tolerance) {
-    return compare_binary_operation_one_output(op, in, out, tolerance);
+  template <typename T> bool match(const BinaryInput<T> &in, T out) {
+    return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
+                                               rounding);
   }
 
   template <typename T>
-  static bool match(BinaryInput<T> in, const BinaryOutput<T> &out,
-                    double tolerance) {
-    return compare_binary_operation_two_outputs(op, in, out, tolerance);
+  bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
+    return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
+                                                rounding);
   }
 
-  template <typename T>
-  static bool match(const TernaryInput<T> &in, T out, double tolerance) {
-    return compare_ternary_operation_one_output(op, in, out, tolerance);
+  template <typename T> bool match(const TernaryInput<T> &in, T out) {
+    return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
+                                                rounding);
   }
 
   template <typename T>
-  static void explain_error(T in, T out, testutils::StreamWrapper &OS) {
-    explain_unary_operation_single_output_error(op, in, out, OS);
+  void explain_error(T in, T out, testutils::StreamWrapper &OS) {
+    explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
+                                                rounding, OS);
   }
 
   template <typename T>
-  static void explain_error(T in, const BinaryOutput<T> &out,
-                            testutils::StreamWrapper &OS) {
-    explain_unary_operation_two_outputs_error(op, in, out, OS);
+  void explain_error(T in, const BinaryOutput<T> &out,
+                     testutils::StreamWrapper &OS) {
+    explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
+                                              rounding, OS);
   }
 
   template <typename T>
-  static void explain_error(const BinaryInput<T> &in,
-                            const BinaryOutput<T> &out,
-                            testutils::StreamWrapper &OS) {
-    explain_binary_operation_two_outputs_error(op, in, out, OS);
+  void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out,
+                     testutils::StreamWrapper &OS) {
+    explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
+                                               rounding, OS);
   }
 
   template <typename T>
-  static void explain_error(const BinaryInput<T> &in, T out,
-                            testutils::StreamWrapper &OS) {
-    explain_binary_operation_one_output_error(op, in, out, OS);
+  void explain_error(const BinaryInput<T> &in, T out,
+                     testutils::StreamWrapper &OS) {
+    explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
+                                              rounding, OS);
   }
 
   template <typename T>
-  static void explain_error(const TernaryInput<T> &in, T out,
-                            testutils::StreamWrapper &OS) {
-    explain_ternary_operation_one_output_error(op, in, out, OS);
+  void explain_error(const TernaryInput<T> &in, T out,
+                     testutils::StreamWrapper &OS) {
+    explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
+                                               rounding, OS);
   }
 };
 
@@ -264,12 +288,12 @@ template <Operation op, typename InputType, typename OutputType>
 __attribute__((no_sanitize("address")))
 cpp::EnableIfType<is_valid_operation<op, InputType, OutputType>(),
                   internal::MPFRMatcher<op, InputType, OutputType>>
-get_mpfr_matcher(InputType input, OutputType output_unused, double t) {
-  return internal::MPFRMatcher<op, InputType, OutputType>(input, t);
+get_mpfr_matcher(InputType input, OutputType output_unused,
+                 double ulp_tolerance, RoundingMode rounding) {
+  return internal::MPFRMatcher<op, InputType, OutputType>(input, ulp_tolerance,
+                                                          rounding);
 }
 
-enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest };
-
 template <typename T> T round(T x, RoundingMode mode);
 
 template <typename T> bool round_to_long(T x, long &result);
@@ -279,12 +303,42 @@ template <typename T> bool round_to_long(T x, RoundingMode mode, long &result);
 } // namespace testing
 } // namespace __llvm_libc
 
-#define EXPECT_MPFR_MATCH(op, input, match_value, tolerance)                   \
+// GET_MPFR_DUMMY_ARG is going to be added to the end of GET_MPFR_MACRO as a
+// simple way to avoid the compiler warning `gnu-zero-variadic-macro-arguments`.
+#define GET_MPFR_DUMMY_ARG(...) 0
+
+#define GET_MPFR_MACRO(__1, __2, __3, __4, __5, __NAME, ...) __NAME
+
+#define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
+  EXPECT_THAT(match_value,                                                     \
+              __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(                \
+                  input, match_value, ulp_tolerance,                           \
+                  __llvm_libc::testing::mpfr::RoundingMode::Nearest))
+
+#define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
+                                   rounding)                                   \
   EXPECT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(   \
-                               input, match_value, tolerance))
+                               input, match_value, ulp_tolerance, rounding))
+
+#define EXPECT_MPFR_MATCH(...)                                                 \
+  GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING,                      \
+                 EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
+  (__VA_ARGS__)
 
-#define ASSERT_MPFR_MATCH(op, input, match_value, tolerance)                   \
+#define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
+  ASSERT_THAT(match_value,                                                     \
+              __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(                \
+                  input, match_value, ulp_tolerance,                           \
+                  __llvm_libc::testing::mpfr::RoundingMode::Nearest))
+
+#define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
+                                   rounding)                                   \
   ASSERT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>(   \
-                               input, match_value, tolerance))
+                               input, match_value, ulp_tolerance, rounding))
+
+#define ASSERT_MPFR_MATCH(...)                                                 \
+  GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING,                      \
+                 ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
+  (__VA_ARGS__)
 
 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H


        


More information about the libc-commits mailing list