[libc-commits] [libc] cce6507 - [libc] Add rounding mode support for MPFR testing macros.
Tue Ly via libc-commits
libc-commits at lists.llvm.org
Thu Jan 13 10:29:11 PST 2022
Author: Tue Ly
Date: 2022-01-13T13:28:50-05:00
New Revision: cce650776722ba36d81941e0f2e31a9602c88682
URL: https://github.com/llvm/llvm-project/commit/cce650776722ba36d81941e0f2e31a9602c88682
DIFF: https://github.com/llvm/llvm-project/commit/cce650776722ba36d81941e0f2e31a9602c88682.diff
LOG: [libc] Add rounding mode support for MPFR testing macros.
Add an extra argument for rounding mode to EXPECT_MPFR_MATCH and ASSERT_MPFR_MATCH macros.
Reviewed By: sivachandra, michaelrj
Differential Revision: https://reviews.llvm.org/D116777
Added:
Modified:
libc/test/src/math/SqrtTest.h
libc/utils/MPFRWrapper/MPFRUtils.cpp
libc/utils/MPFRWrapper/MPFRUtils.h
Removed:
################################################################################
diff --git a/libc/test/src/math/SqrtTest.h b/libc/test/src/math/SqrtTest.h
index f3e4def723aaf..79306a60e2baa 100644
--- a/libc/test/src/math/SqrtTest.h
+++ b/libc/test/src/math/SqrtTest.h
@@ -1,4 +1,4 @@
-//===-- Utility class to test fabs[f|l] -------------------------*- C++ -*-===//
+//===-- Utility class to test sqrt[f|l] -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -24,7 +24,7 @@ template <typename T> class SqrtTest : public __llvm_libc::testing::Test {
public:
typedef T (*SqrtFunc)(T);
- void testSpecialNumbers(SqrtFunc func) {
+ void test_special_numbers(SqrtFunc func) {
ASSERT_FP_EQ(aNaN, func(aNaN));
ASSERT_FP_EQ(inf, func(inf));
ASSERT_FP_EQ(aNaN, func(neg_inf));
@@ -36,24 +36,23 @@ template <typename T> class SqrtTest : public __llvm_libc::testing::Test {
ASSERT_FP_EQ(T(3.0), func(T(9.0)));
}
- void testDenormalValues(SqrtFunc func) {
+ void test_denormal_values(SqrtFunc func) {
for (UIntType mant = 1; mant < HIDDEN_BIT; mant <<= 1) {
FPBits denormal(T(0.0));
denormal.set_mantissa(mant);
- ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, T(denormal), func(T(denormal)),
- T(0.5));
+ test_all_rounding_modes(func, T(denormal));
}
constexpr UIntType COUNT = 1'000'001;
constexpr UIntType STEP = HIDDEN_BIT / COUNT;
for (UIntType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
T x = *reinterpret_cast<T *>(&v);
- ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5);
+ test_all_rounding_modes(func, x);
}
}
- void testNormalRange(SqrtFunc func) {
+ void test_normal_range(SqrtFunc func) {
constexpr UIntType COUNT = 10'000'001;
constexpr UIntType STEP = UIntType(-1) / COUNT;
for (UIntType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
@@ -61,13 +60,31 @@ template <typename T> class SqrtTest : public __llvm_libc::testing::Test {
if (isnan(x) || (x < 0)) {
continue;
}
- ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5);
+ test_all_rounding_modes(func, x);
}
}
+
+ void test_all_rounding_modes(SqrtFunc func, T x) {
+ mpfr::ForceRoundingMode r1(mpfr::RoundingMode::Nearest);
+ EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
+ mpfr::RoundingMode::Nearest);
+
+ mpfr::ForceRoundingMode r2(mpfr::RoundingMode::Upward);
+ EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
+ mpfr::RoundingMode::Upward);
+
+ mpfr::ForceRoundingMode r3(mpfr::RoundingMode::Downward);
+ EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
+ mpfr::RoundingMode::Downward);
+
+ mpfr::ForceRoundingMode r4(mpfr::RoundingMode::TowardZero);
+ EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
+ mpfr::RoundingMode::TowardZero);
+ }
};
#define LIST_SQRT_TESTS(T, func) \
using LlvmLibcSqrtTest = SqrtTest<T>; \
- TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { testSpecialNumbers(&func); } \
- TEST_F(LlvmLibcSqrtTest, DenormalValues) { testDenormalValues(&func); } \
- TEST_F(LlvmLibcSqrtTest, NormalRange) { testNormalRange(&func); }
+ TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); } \
+ TEST_F(LlvmLibcSqrtTest, DenormalValues) { test_denormal_values(&func); } \
+ TEST_F(LlvmLibcSqrtTest, NormalRange) { test_normal_range(&func); }
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 4e25c874d122f..7bb3a7fc5f598 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -14,6 +14,7 @@
#include "utils/UnitTest/FPMatcher.h"
#include <cmath>
+#include <fenv.h>
#include <memory>
#include <stdint.h>
#include <string>
@@ -55,141 +56,227 @@ template <> struct Precision<long double> {
};
#endif
+// A precision value which allows sufficiently large additional
+// precision compared to the floating point precision.
+template <typename T> struct ExtraPrecision;
+
+template <> struct ExtraPrecision<float> {
+ static constexpr unsigned int VALUE = 128;
+};
+
+template <> struct ExtraPrecision<double> {
+ static constexpr unsigned int VALUE = 256;
+};
+
+template <> struct ExtraPrecision<long double> {
+ static constexpr unsigned int VALUE = 256;
+};
+
+// If the ulp tolerance is less than or equal to 0.5, we would check that the
+// result is rounded correctly with respect to the rounding mode by using the
+// same precision as the inputs.
+template <typename T>
+static inline unsigned int get_precision(double ulp_tolerance) {
+ if (ulp_tolerance <= 0.5) {
+ return Precision<T>::VALUE;
+ } else {
+ return ExtraPrecision<T>::VALUE;
+ }
+}
+
+static inline mpfr_rnd_t get_mpfr_rounding_mode(RoundingMode mode) {
+ switch (mode) {
+ case RoundingMode::Upward:
+ return MPFR_RNDU;
+ break;
+ case RoundingMode::Downward:
+ return MPFR_RNDD;
+ break;
+ case RoundingMode::TowardZero:
+ return MPFR_RNDZ;
+ break;
+ case RoundingMode::Nearest:
+ return MPFR_RNDN;
+ break;
+ }
+}
+
+int get_fe_rounding(RoundingMode mode) {
+ switch (mode) {
+ case RoundingMode::Upward:
+ return FE_UPWARD;
+ break;
+ case RoundingMode::Downward:
+ return FE_DOWNWARD;
+ break;
+ case RoundingMode::TowardZero:
+ return FE_TOWARDZERO;
+ break;
+ case RoundingMode::Nearest:
+ return FE_TONEAREST;
+ break;
+ }
+}
+
+ForceRoundingMode::ForceRoundingMode(RoundingMode mode) {
+ old_rounding_mode = fegetround();
+ rounding_mode = get_fe_rounding(mode);
+ if (old_rounding_mode != rounding_mode)
+ fesetround(rounding_mode);
+}
+
+ForceRoundingMode::~ForceRoundingMode() {
+ if (old_rounding_mode != rounding_mode)
+ fesetround(old_rounding_mode);
+}
+
class MPFRNumber {
- // A precision value which allows sufficiently large additional
- // precision even compared to quad-precision floating point values.
unsigned int mpfr_precision;
+ mpfr_rnd_t mpfr_rounding;
mpfr_t value;
public:
- MPFRNumber() : mpfr_precision(256) { mpfr_init2(value, mpfr_precision); }
+ MPFRNumber() : mpfr_precision(256), mpfr_rounding(MPFR_RNDN) {
+ mpfr_init2(value, mpfr_precision);
+ }
// We use explicit EnableIf specializations to disallow implicit
// conversions. Implicit conversions can potentially lead to loss of
// precision.
template <typename XType,
cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
- explicit MPFRNumber(XType x, int precision = 128)
- : mpfr_precision(precision) {
+ explicit MPFRNumber(XType x, int precision = ExtraPrecision<XType>::VALUE,
+ RoundingMode rounding = RoundingMode::Nearest)
+ : mpfr_precision(precision),
+ mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
mpfr_init2(value, mpfr_precision);
- mpfr_set_flt(value, x, MPFR_RNDN);
+ mpfr_set_flt(value, x, mpfr_rounding);
}
template <typename XType,
cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
- explicit MPFRNumber(XType x, int precision = 128)
- : mpfr_precision(precision) {
+ explicit MPFRNumber(XType x, int precision = ExtraPrecision<XType>::VALUE,
+ RoundingMode rounding = RoundingMode::Nearest)
+ : mpfr_precision(precision),
+ mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
mpfr_init2(value, mpfr_precision);
- mpfr_set_d(value, x, MPFR_RNDN);
+ mpfr_set_d(value, x, mpfr_rounding);
}
template <typename XType,
cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
- explicit MPFRNumber(XType x, int precision = 128)
- : mpfr_precision(precision) {
+ explicit MPFRNumber(XType x, int precision = ExtraPrecision<XType>::VALUE,
+ RoundingMode rounding = RoundingMode::Nearest)
+ : mpfr_precision(precision),
+ mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
mpfr_init2(value, mpfr_precision);
- mpfr_set_ld(value, x, MPFR_RNDN);
+ mpfr_set_ld(value, x, mpfr_rounding);
}
template <typename XType,
cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
- explicit MPFRNumber(XType x, int precision = 128)
- : mpfr_precision(precision) {
+ explicit MPFRNumber(XType x, int precision = ExtraPrecision<float>::VALUE,
+ RoundingMode rounding = RoundingMode::Nearest)
+ : mpfr_precision(precision),
+ mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
mpfr_init2(value, mpfr_precision);
- mpfr_set_sj(value, x, MPFR_RNDN);
+ mpfr_set_sj(value, x, mpfr_rounding);
}
- MPFRNumber(const MPFRNumber &other) : mpfr_precision(other.mpfr_precision) {
+ MPFRNumber(const MPFRNumber &other)
+ : mpfr_precision(other.mpfr_precision),
+ mpfr_rounding(other.mpfr_rounding) {
mpfr_init2(value, mpfr_precision);
- mpfr_set(value, other.value, MPFR_RNDN);
+ mpfr_set(value, other.value, mpfr_rounding);
}
~MPFRNumber() { mpfr_clear(value); }
MPFRNumber &operator=(const MPFRNumber &rhs) {
mpfr_precision = rhs.mpfr_precision;
- mpfr_set(value, rhs.value, MPFR_RNDN);
+ mpfr_rounding = rhs.mpfr_rounding;
+ mpfr_set(value, rhs.value, mpfr_rounding);
return *this;
}
MPFRNumber abs() const {
- MPFRNumber result;
- mpfr_abs(result.value, value, MPFR_RNDN);
+ MPFRNumber result(*this);
+ mpfr_abs(result.value, value, mpfr_rounding);
return result;
}
MPFRNumber ceil() const {
- MPFRNumber result;
+ MPFRNumber result(*this);
mpfr_ceil(result.value, value);
return result;
}
MPFRNumber cos() const {
- MPFRNumber result;
- mpfr_cos(result.value, value, MPFR_RNDN);
+ MPFRNumber result(*this);
+ mpfr_cos(result.value, value, mpfr_rounding);
return result;
}
MPFRNumber exp() const {
- MPFRNumber result;
- mpfr_exp(result.value, value, MPFR_RNDN);
+ MPFRNumber result(*this);
+ mpfr_exp(result.value, value, mpfr_rounding);
return result;
}
MPFRNumber exp2() const {
- MPFRNumber result;
- mpfr_exp2(result.value, value, MPFR_RNDN);
+ MPFRNumber result(*this);
+ mpfr_exp2(result.value, value, mpfr_rounding);
return result;
}
MPFRNumber expm1() const {
- MPFRNumber result;
- mpfr_expm1(result.value, value, MPFR_RNDN);
+ MPFRNumber result(*this);
+ mpfr_expm1(result.value, value, mpfr_rounding);
return result;
}
MPFRNumber floor() const {
- MPFRNumber result;
+ MPFRNumber result(*this);
mpfr_floor(result.value, value);
return result;
}
MPFRNumber frexp(int &exp) {
- MPFRNumber result;
+ MPFRNumber result(*this);
mpfr_exp_t resultExp;
- mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN);
+ mpfr_frexp(&resultExp, result.value, value, mpfr_rounding);
exp = resultExp;
return result;
}
MPFRNumber hypot(const MPFRNumber &b) {
- MPFRNumber result;
- mpfr_hypot(result.value, value, b.value, MPFR_RNDN);
+ MPFRNumber result(*this);
+ mpfr_hypot(result.value, value, b.value, mpfr_rounding);
return result;
}
MPFRNumber log() const {
- MPFRNumber result;
- mpfr_log(result.value, value, MPFR_RNDN);
+ MPFRNumber result(*this);
+ mpfr_log(result.value, value, mpfr_rounding);
return result;
}
MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) {
- MPFRNumber remainder;
+ MPFRNumber remainder(*this);
long q;
- mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN);
+ mpfr_remquo(remainder.value, &q, value, divisor.value, mpfr_rounding);
quotient = q;
return remainder;
}
MPFRNumber round() const {
- MPFRNumber result;
+ MPFRNumber result(*this);
mpfr_round(result.value, value);
return result;
}
- bool roung_to_long(long &result) const {
+ bool round_to_long(long &result) const {
// We first calculate the rounded value. This way, when converting
// to long using mpfr_get_si, the rounding direction of MPFR_RNDN
// (or any other rounding mode), does not have an influence.
@@ -199,14 +286,14 @@ class MPFRNumber {
return mpfr_erangeflag_p();
}
- bool roung_to_long(mpfr_rnd_t rnd, long &result) const {
- MPFRNumber rint_result;
+ bool round_to_long(mpfr_rnd_t rnd, long &result) const {
+ MPFRNumber rint_result(*this);
mpfr_rint(rint_result.value, value, rnd);
- return rint_result.roung_to_long(result);
+ return rint_result.round_to_long(result);
}
MPFRNumber rint(mpfr_rnd_t rnd) const {
- MPFRNumber result;
+ MPFRNumber result(*this);
mpfr_rint(result.value, value, rnd);
return result;
}
@@ -239,32 +326,32 @@ class MPFRNumber {
}
MPFRNumber sin() const {
- MPFRNumber result;
- mpfr_sin(result.value, value, MPFR_RNDN);
+ MPFRNumber result(*this);
+ mpfr_sin(result.value, value, mpfr_rounding);
return result;
}
MPFRNumber sqrt() const {
- MPFRNumber result;
- mpfr_sqrt(result.value, value, MPFR_RNDN);
+ MPFRNumber result(*this);
+ mpfr_sqrt(result.value, value, mpfr_rounding);
return result;
}
MPFRNumber tan() const {
- MPFRNumber result;
- mpfr_tan(result.value, value, MPFR_RNDN);
+ MPFRNumber result(*this);
+ mpfr_tan(result.value, value, mpfr_rounding);
return result;
}
MPFRNumber trunc() const {
- MPFRNumber result;
+ MPFRNumber result(*this);
mpfr_trunc(result.value, value);
return result;
}
MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) {
MPFRNumber result(*this);
- mpfr_fma(result.value, value, b.value, c.value, MPFR_RNDN);
+ mpfr_fma(result.value, value, b.value, c.value, mpfr_rounding);
return result;
}
@@ -282,10 +369,14 @@ class MPFRNumber {
// These functions are useful for debugging.
template <typename T> T as() const;
- template <> float as<float>() const { return mpfr_get_flt(value, MPFR_RNDN); }
- template <> double as<double>() const { return mpfr_get_d(value, MPFR_RNDN); }
+ template <> float as<float>() const {
+ return mpfr_get_flt(value, mpfr_rounding);
+ }
+ template <> double as<double>() const {
+ return mpfr_get_d(value, mpfr_rounding);
+ }
template <> long double as<long double>() const {
- return mpfr_get_ld(value, MPFR_RNDN);
+ return mpfr_get_ld(value, mpfr_rounding);
}
void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
@@ -378,8 +469,9 @@ namespace internal {
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
-unary_operation(Operation op, InputType input) {
- MPFRNumber mpfrInput(input);
+unary_operation(Operation op, InputType input, unsigned int precision,
+ RoundingMode rounding) {
+ MPFRNumber mpfrInput(input, precision, rounding);
switch (op) {
case Operation::Abs:
return mpfrInput.abs();
@@ -420,8 +512,9 @@ unary_operation(Operation op, InputType input) {
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
-unary_operation_two_outputs(Operation op, InputType input, int &output) {
- MPFRNumber mpfrInput(input);
+unary_operation_two_outputs(Operation op, InputType input, int &output,
+ unsigned int precision, RoundingMode rounding) {
+ MPFRNumber mpfrInput(input, precision, rounding);
switch (op) {
case Operation::Frexp:
return mpfrInput.frexp(output);
@@ -432,8 +525,10 @@ unary_operation_two_outputs(Operation op, InputType input, int &output) {
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
-binary_operation_one_output(Operation op, InputType x, InputType y) {
- MPFRNumber inputX(x), inputY(y);
+binary_operation_one_output(Operation op, InputType x, InputType y,
+ unsigned int precision, RoundingMode rounding) {
+ MPFRNumber inputX(x, precision, rounding);
+ MPFRNumber inputY(y, precision, rounding);
switch (op) {
case Operation::Hypot:
return inputX.hypot(inputY);
@@ -445,8 +540,10 @@ binary_operation_one_output(Operation op, InputType x, InputType y) {
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
binary_operation_two_outputs(Operation op, InputType x, InputType y,
- int &output) {
- MPFRNumber inputX(x), inputY(y);
+ int &output, unsigned int precision,
+ RoundingMode rounding) {
+ MPFRNumber inputX(x, precision, rounding);
+ MPFRNumber inputY(y, precision, rounding);
switch (op) {
case Operation::RemQuo:
return inputX.remquo(inputY, output);
@@ -458,12 +555,14 @@ binary_operation_two_outputs(Operation op, InputType x, InputType y,
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
ternary_operation_one_output(Operation op, InputType x, InputType y,
- InputType z) {
+ InputType z, unsigned int precision,
+ RoundingMode rounding) {
// For FMA function, we just need to compare with the mpfr_fma with the same
// precision as InputType. Using higher precision as the intermediate results
// to compare might incorrectly fail due to double-rounding errors.
- constexpr unsigned int prec = Precision<InputType>::VALUE;
- MPFRNumber inputX(x, prec), inputY(y, prec), inputZ(z, prec);
+ MPFRNumber inputX(x, precision, rounding);
+ MPFRNumber inputY(y, precision, rounding);
+ MPFRNumber inputZ(z, precision, rounding);
switch (op) {
case Operation::Fma:
return inputX.fma(inputY, inputZ);
@@ -475,13 +574,14 @@ ternary_operation_one_output(Operation op, InputType x, InputType y,
template <typename T>
void explain_unary_operation_single_output_error(Operation op, T input,
T matchValue,
+ double ulp_tolerance,
+ RoundingMode rounding,
testutils::StreamWrapper &OS) {
- MPFRNumber mpfrInput(input);
- MPFRNumber mpfr_result = unary_operation(op, input);
+ unsigned int precision = get_precision<T>(ulp_tolerance);
+ MPFRNumber mpfrInput(input, precision);
+ MPFRNumber mpfr_result;
+ mpfr_result = unary_operation(op, input, precision, rounding);
MPFRNumber mpfrMatchValue(matchValue);
- FPBits<T> inputBits(input);
- FPBits<T> matchBits(matchValue);
- FPBits<T> mpfr_resultBits(mpfr_result.as<T>());
OS << "Match value not within tolerance value of MPFR result:\n"
<< " Input decimal: " << mpfrInput.str() << '\n';
__llvm_libc::fputil::testing::describeValue(" Input bits: ", input, OS);
@@ -498,21 +598,24 @@ void explain_unary_operation_single_output_error(Operation op, T input,
template void
explain_unary_operation_single_output_error<float>(Operation op, float, float,
+ double, RoundingMode,
testutils::StreamWrapper &);
template void explain_unary_operation_single_output_error<double>(
- Operation op, double, double, testutils::StreamWrapper &);
+ Operation op, double, double, double, RoundingMode,
+ testutils::StreamWrapper &);
template void explain_unary_operation_single_output_error<long double>(
- Operation op, long double, long double, testutils::StreamWrapper &);
+ Operation op, long double, long double, double, RoundingMode,
+ testutils::StreamWrapper &);
template <typename T>
void explain_unary_operation_two_outputs_error(
Operation op, T input, const BinaryOutput<T> &libc_result,
- testutils::StreamWrapper &OS) {
- MPFRNumber mpfrInput(input);
- FPBits<T> inputBits(input);
+ double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) {
+ unsigned int precision = get_precision<T>(ulp_tolerance);
+ MPFRNumber mpfrInput(input, precision);
int mpfrIntResult;
- MPFRNumber mpfr_result =
- unary_operation_two_outputs(op, input, mpfrIntResult);
+ MPFRNumber mpfr_result = unary_operation_two_outputs(op, input, mpfrIntResult,
+ precision, rounding);
if (mpfrIntResult != libc_result.i) {
OS << "MPFR integral result: " << mpfrIntResult << '\n'
@@ -541,26 +644,26 @@ void explain_unary_operation_two_outputs_error(
}
template void explain_unary_operation_two_outputs_error<float>(
- Operation, float, const BinaryOutput<float> &, testutils::StreamWrapper &);
-template void
-explain_unary_operation_two_outputs_error<double>(Operation, double,
- const BinaryOutput<double> &,
- testutils::StreamWrapper &);
-template void explain_unary_operation_two_outputs_error<long double>(
- Operation, long double, const BinaryOutput<long double> &,
+ Operation, float, const BinaryOutput<float> &, double, RoundingMode,
testutils::StreamWrapper &);
+template void explain_unary_operation_two_outputs_error<double>(
+ Operation, double, const BinaryOutput<double> &, double, RoundingMode,
+ testutils::StreamWrapper &);
+template void explain_unary_operation_two_outputs_error<long double>(
+ Operation, long double, const BinaryOutput<long double> &, double,
+ RoundingMode, testutils::StreamWrapper &);
template <typename T>
void explain_binary_operation_two_outputs_error(
Operation op, const BinaryInput<T> &input,
- const BinaryOutput<T> &libc_result, testutils::StreamWrapper &OS) {
- MPFRNumber mpfrX(input.x);
- MPFRNumber mpfrY(input.y);
- FPBits<T> xbits(input.x);
- FPBits<T> ybits(input.y);
+ const BinaryOutput<T> &libc_result, double ulp_tolerance,
+ RoundingMode rounding, testutils::StreamWrapper &OS) {
+ unsigned int precision = get_precision<T>(ulp_tolerance);
+ MPFRNumber mpfrX(input.x, precision);
+ MPFRNumber mpfrY(input.y, precision);
int mpfrIntResult;
- MPFRNumber mpfr_result =
- binary_operation_two_outputs(op, input.x, input.y, mpfrIntResult);
+ MPFRNumber mpfr_result = binary_operation_two_outputs(
+ op, input.x, input.y, mpfrIntResult, precision, rounding);
MPFRNumber mpfrMatchValue(libc_result.f);
OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n'
@@ -576,25 +679,27 @@ void explain_binary_operation_two_outputs_error(
}
template void explain_binary_operation_two_outputs_error<float>(
- Operation, const BinaryInput<float> &, const BinaryOutput<float> &,
- testutils::StreamWrapper &);
+ Operation, const BinaryInput<float> &, const BinaryOutput<float> &, double,
+ RoundingMode, testutils::StreamWrapper &);
template void explain_binary_operation_two_outputs_error<double>(
Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
- testutils::StreamWrapper &);
+ double, RoundingMode, testutils::StreamWrapper &);
template void explain_binary_operation_two_outputs_error<long double>(
Operation, const BinaryInput<long double> &,
- const BinaryOutput<long double> &, testutils::StreamWrapper &);
+ const BinaryOutput<long double> &, double, RoundingMode,
+ testutils::StreamWrapper &);
template <typename T>
-void explain_binary_operation_one_output_error(Operation op,
- const BinaryInput<T> &input,
- T libc_result,
- testutils::StreamWrapper &OS) {
- MPFRNumber mpfrX(input.x);
- MPFRNumber mpfrY(input.y);
+void explain_binary_operation_one_output_error(
+ Operation op, const BinaryInput<T> &input, T libc_result,
+ double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) {
+ unsigned int precision = get_precision<T>(ulp_tolerance);
+ MPFRNumber mpfrX(input.x, precision);
+ MPFRNumber mpfrY(input.y, precision);
FPBits<T> xbits(input.x);
FPBits<T> ybits(input.y);
- MPFRNumber mpfr_result = binary_operation_one_output(op, input.x, input.y);
+ MPFRNumber mpfr_result =
+ binary_operation_one_output(op, input.x, input.y, precision, rounding);
MPFRNumber mpfrMatchValue(libc_result);
OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n';
@@ -613,26 +718,28 @@ void explain_binary_operation_one_output_error(Operation op,
}
template void explain_binary_operation_one_output_error<float>(
- Operation, const BinaryInput<float> &, float, testutils::StreamWrapper &);
+ Operation, const BinaryInput<float> &, float, double, RoundingMode,
+ testutils::StreamWrapper &);
template void explain_binary_operation_one_output_error<double>(
- Operation, const BinaryInput<double> &, double, testutils::StreamWrapper &);
-template void explain_binary_operation_one_output_error<long double>(
- Operation, const BinaryInput<long double> &, long double,
+ Operation, const BinaryInput<double> &, double, double, RoundingMode,
testutils::StreamWrapper &);
+template void explain_binary_operation_one_output_error<long double>(
+ Operation, const BinaryInput<long double> &, long double, double,
+ RoundingMode, testutils::StreamWrapper &);
template <typename T>
-void explain_ternary_operation_one_output_error(Operation op,
- const TernaryInput<T> &input,
- T libc_result,
- testutils::StreamWrapper &OS) {
- MPFRNumber mpfrX(input.x, Precision<T>::VALUE);
- MPFRNumber mpfrY(input.y, Precision<T>::VALUE);
- MPFRNumber mpfrZ(input.z, Precision<T>::VALUE);
+void explain_ternary_operation_one_output_error(
+ Operation op, const TernaryInput<T> &input, T libc_result,
+ double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) {
+ unsigned int precision = get_precision<T>(ulp_tolerance);
+ MPFRNumber mpfrX(input.x, precision);
+ MPFRNumber mpfrY(input.y, precision);
+ MPFRNumber mpfrZ(input.z, precision);
FPBits<T> xbits(input.x);
FPBits<T> ybits(input.y);
FPBits<T> zbits(input.z);
- MPFRNumber mpfr_result =
- ternary_operation_one_output(op, input.x, input.y, input.z);
+ MPFRNumber mpfr_result = ternary_operation_one_output(
+ op, input.x, input.y, input.z, precision, rounding);
MPFRNumber mpfrMatchValue(libc_result);
OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str()
@@ -654,68 +761,70 @@ void explain_ternary_operation_one_output_error(Operation op,
}
template void explain_ternary_operation_one_output_error<float>(
- Operation, const TernaryInput<float> &, float, testutils::StreamWrapper &);
+ Operation, const TernaryInput<float> &, float, double, RoundingMode,
+ testutils::StreamWrapper &);
template void explain_ternary_operation_one_output_error<double>(
- Operation, const TernaryInput<double> &, double,
+ Operation, const TernaryInput<double> &, double, double, RoundingMode,
testutils::StreamWrapper &);
template void explain_ternary_operation_one_output_error<long double>(
- Operation, const TernaryInput<long double> &, long double,
- testutils::StreamWrapper &);
+ Operation, const TernaryInput<long double> &, long double, double,
+ RoundingMode, testutils::StreamWrapper &);
template <typename T>
bool compare_unary_operation_single_output(Operation op, T input, T libc_result,
- double ulp_error) {
- // If the ulp error is exactly 0.5 (i.e a tie), we would check that the result
- // is rounded to the nearest even.
- MPFRNumber mpfr_result = unary_operation(op, input);
+ double ulp_tolerance,
+ RoundingMode rounding) {
+ unsigned int precision = get_precision<T>(ulp_tolerance);
+ MPFRNumber mpfr_result;
+ mpfr_result = unary_operation(op, input, precision, rounding);
double ulp = mpfr_result.ulp(libc_result);
- bool bits_are_even = ((FPBits<T>(libc_result).uintval() & 1) == 0);
- return (ulp < ulp_error) ||
- ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
+ return (ulp <= ulp_tolerance);
}
template bool compare_unary_operation_single_output<float>(Operation, float,
- float, double);
+ float, double,
+ RoundingMode);
template bool compare_unary_operation_single_output<double>(Operation, double,
- double, double);
-template bool compare_unary_operation_single_output<long double>(Operation,
- long double,
- long double,
- double);
+ double, double,
+ RoundingMode);
+template bool compare_unary_operation_single_output<long double>(
+ Operation, long double, long double, double, RoundingMode);
template <typename T>
bool compare_unary_operation_two_outputs(Operation op, T input,
const BinaryOutput<T> &libc_result,
- double ulp_error) {
+ double ulp_tolerance,
+ RoundingMode rounding) {
int mpfrIntResult;
- MPFRNumber mpfr_result =
- unary_operation_two_outputs(op, input, mpfrIntResult);
+ unsigned int precision = get_precision<T>(ulp_tolerance);
+ MPFRNumber mpfr_result = unary_operation_two_outputs(op, input, mpfrIntResult,
+ precision, rounding);
double ulp = mpfr_result.ulp(libc_result.f);
if (mpfrIntResult != libc_result.i)
return false;
- bool bits_are_even = ((FPBits<T>(libc_result.f).uintval() & 1) == 0);
- return (ulp < ulp_error) ||
- ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
+ return (ulp <= ulp_tolerance);
}
-template bool
-compare_unary_operation_two_outputs<float>(Operation, float,
- const BinaryOutput<float> &, double);
+template bool compare_unary_operation_two_outputs<float>(
+ Operation, float, const BinaryOutput<float> &, double, RoundingMode);
template bool compare_unary_operation_two_outputs<double>(
- Operation, double, const BinaryOutput<double> &, double);
+ Operation, double, const BinaryOutput<double> &, double, RoundingMode);
template bool compare_unary_operation_two_outputs<long double>(
- Operation, long double, const BinaryOutput<long double> &, double);
+ Operation, long double, const BinaryOutput<long double> &, double,
+ RoundingMode);
template <typename T>
bool compare_binary_operation_two_outputs(Operation op,
const BinaryInput<T> &input,
const BinaryOutput<T> &libc_result,
- double ulp_error) {
+ double ulp_tolerance,
+ RoundingMode rounding) {
int mpfrIntResult;
- MPFRNumber mpfr_result =
- binary_operation_two_outputs(op, input.x, input.y, mpfrIntResult);
+ unsigned int precision = get_precision<T>(ulp_tolerance);
+ MPFRNumber mpfr_result = binary_operation_two_outputs(
+ op, input.x, input.y, mpfrIntResult, precision, rounding);
double ulp = mpfr_result.ulp(libc_result.f);
if (mpfrIntResult != libc_result.i) {
@@ -727,81 +836,66 @@ bool compare_binary_operation_two_outputs(Operation op,
}
}
- bool bits_are_even = ((FPBits<T>(libc_result.f).uintval() & 1) == 0);
- return (ulp < ulp_error) ||
- ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
+ return (ulp <= ulp_tolerance);
}
template bool compare_binary_operation_two_outputs<float>(
- Operation, const BinaryInput<float> &, const BinaryOutput<float> &, double);
+ Operation, const BinaryInput<float> &, const BinaryOutput<float> &, double,
+ RoundingMode);
template bool compare_binary_operation_two_outputs<double>(
Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
- double);
+ double, RoundingMode);
template bool compare_binary_operation_two_outputs<long double>(
Operation, const BinaryInput<long double> &,
- const BinaryOutput<long double> &, double);
+ const BinaryOutput<long double> &, double, RoundingMode);
template <typename T>
bool compare_binary_operation_one_output(Operation op,
const BinaryInput<T> &input,
- T libc_result, double ulp_error) {
- MPFRNumber mpfr_result = binary_operation_one_output(op, input.x, input.y);
+ T libc_result, double ulp_tolerance,
+ RoundingMode rounding) {
+ unsigned int precision = get_precision<T>(ulp_tolerance);
+ MPFRNumber mpfr_result =
+ binary_operation_one_output(op, input.x, input.y, precision, rounding);
double ulp = mpfr_result.ulp(libc_result);
- bool bits_are_even = ((FPBits<T>(libc_result).uintval() & 1) == 0);
- return (ulp < ulp_error) ||
- ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
+ return (ulp <= ulp_tolerance);
}
template bool compare_binary_operation_one_output<float>(
- Operation, const BinaryInput<float> &, float, double);
+ Operation, const BinaryInput<float> &, float, double, RoundingMode);
template bool compare_binary_operation_one_output<double>(
- Operation, const BinaryInput<double> &, double, double);
+ Operation, const BinaryInput<double> &, double, double, RoundingMode);
template bool compare_binary_operation_one_output<long double>(
- Operation, const BinaryInput<long double> &, long double, double);
+ Operation, const BinaryInput<long double> &, long double, double,
+ RoundingMode);
template <typename T>
bool compare_ternary_operation_one_output(Operation op,
const TernaryInput<T> &input,
- T libc_result, double ulp_error) {
- MPFRNumber mpfr_result =
- ternary_operation_one_output(op, input.x, input.y, input.z);
+ T libc_result, double ulp_tolerance,
+ RoundingMode rounding) {
+ unsigned int precision = get_precision<T>(ulp_tolerance);
+ MPFRNumber mpfr_result = ternary_operation_one_output(
+ op, input.x, input.y, input.z, precision, rounding);
double ulp = mpfr_result.ulp(libc_result);
- bool bits_are_even = ((FPBits<T>(libc_result).uintval() & 1) == 0);
- return (ulp < ulp_error) ||
- ((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
+ return (ulp <= ulp_tolerance);
}
template bool compare_ternary_operation_one_output<float>(
- Operation, const TernaryInput<float> &, float, double);
+ Operation, const TernaryInput<float> &, float, double, RoundingMode);
template bool compare_ternary_operation_one_output<double>(
- Operation, const TernaryInput<double> &, double, double);
+ Operation, const TernaryInput<double> &, double, double, RoundingMode);
template bool compare_ternary_operation_one_output<long double>(
- Operation, const TernaryInput<long double> &, long double, double);
-
-static mpfr_rnd_t get_mpfr_rounding_mode(RoundingMode mode) {
- switch (mode) {
- case RoundingMode::Upward:
- return MPFR_RNDU;
- break;
- case RoundingMode::Downward:
- return MPFR_RNDD;
- break;
- case RoundingMode::TowardZero:
- return MPFR_RNDZ;
- break;
- case RoundingMode::Nearest:
- return MPFR_RNDN;
- break;
- }
-}
+ Operation, const TernaryInput<long double> &, long double, double,
+ RoundingMode);
} // namespace internal
template <typename T> bool round_to_long(T x, long &result) {
MPFRNumber mpfr(x);
- return mpfr.roung_to_long(result);
+ return mpfr.round_to_long(result);
}
template bool round_to_long<float>(float, long &);
@@ -810,7 +904,7 @@ template bool round_to_long<long double>(long double, long &);
template <typename T> bool round_to_long(T x, RoundingMode mode, long &result) {
MPFRNumber mpfr(x);
- return mpfr.roung_to_long(internal::get_mpfr_rounding_mode(mode), result);
+ return mpfr.round_to_long(get_mpfr_rounding_mode(mode), result);
}
template bool round_to_long<float>(float, RoundingMode, long &);
@@ -819,7 +913,7 @@ template bool round_to_long<long double>(long double, RoundingMode, long &);
template <typename T> T round(T x, RoundingMode mode) {
MPFRNumber mpfr(x);
- MPFRNumber result = mpfr.rint(internal::get_mpfr_rounding_mode(mode));
+ MPFRNumber result = mpfr.rint(get_mpfr_rounding_mode(mode));
return result.as<T>();
}
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index c52f2a76d9c6e..07c24149a5e3a 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -71,6 +71,18 @@ enum class Operation : int {
EndTernaryOperationsSingleOutput,
};
+enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest };
+
+int get_fe_rounding(RoundingMode mode);
+
+struct ForceRoundingMode {
+ ForceRoundingMode(RoundingMode);
+ ~ForceRoundingMode();
+
+ int old_rounding_mode;
+ int rounding_mode;
+};
+
template <typename T> struct BinaryInput {
static_assert(
__llvm_libc::cpp::IsFloatingPointType<T>::Value,
@@ -108,65 +120,72 @@ struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
template <typename T>
bool compare_unary_operation_single_output(Operation op, T input, T libc_output,
- double t);
+ double ulp_tolerance,
+ RoundingMode rounding);
template <typename T>
bool compare_unary_operation_two_outputs(Operation op, T input,
const BinaryOutput<T> &libc_output,
- double t);
+ double ulp_tolerance,
+ RoundingMode rounding);
template <typename T>
bool compare_binary_operation_two_outputs(Operation op,
const BinaryInput<T> &input,
const BinaryOutput<T> &libc_output,
- double t);
+ double ulp_tolerance,
+ RoundingMode rounding);
template <typename T>
bool compare_binary_operation_one_output(Operation op,
const BinaryInput<T> &input,
- T libc_output, double t);
+ T libc_output, double ulp_tolerance,
+ RoundingMode rounding);
template <typename T>
bool compare_ternary_operation_one_output(Operation op,
const TernaryInput<T> &input,
- T libc_output, double t);
+ T libc_output, double ulp_tolerance,
+ RoundingMode rounding);
template <typename T>
void explain_unary_operation_single_output_error(Operation op, T input,
T match_value,
+ double ulp_tolerance,
+ RoundingMode rounding,
testutils::StreamWrapper &OS);
template <typename T>
void explain_unary_operation_two_outputs_error(
Operation op, T input, const BinaryOutput<T> &match_value,
- testutils::StreamWrapper &OS);
+ double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
template <typename T>
void explain_binary_operation_two_outputs_error(
Operation op, const BinaryInput<T> &input,
- const BinaryOutput<T> &match_value, testutils::StreamWrapper &OS);
+ const BinaryOutput<T> &match_value, double ulp_tolerance,
+ RoundingMode rounding, testutils::StreamWrapper &OS);
template <typename T>
-void explain_binary_operation_one_output_error(Operation op,
- const BinaryInput<T> &input,
- T match_value,
- testutils::StreamWrapper &OS);
+void explain_binary_operation_one_output_error(
+ Operation op, const BinaryInput<T> &input, T match_value,
+ double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
template <typename T>
-void explain_ternary_operation_one_output_error(Operation op,
- const TernaryInput<T> &input,
- T match_value,
- testutils::StreamWrapper &OS);
+void explain_ternary_operation_one_output_error(
+ Operation op, const TernaryInput<T> &input, T match_value,
+ double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
template <Operation op, typename InputType, typename OutputType>
class MPFRMatcher : public testing::Matcher<OutputType> {
InputType input;
OutputType match_value;
double ulp_tolerance;
+ RoundingMode rounding;
public:
- MPFRMatcher(InputType testInput, double ulp_tolerance)
- : input(testInput), ulp_tolerance(ulp_tolerance) {}
+ MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
+ : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}
bool match(OutputType libcResult) {
match_value = libcResult;
- return match(input, match_value, ulp_tolerance);
+ return match(input, match_value);
}
// This method is marked with NOLINT because the name `explainError`
@@ -176,59 +195,64 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
}
private:
- template <typename T> static bool match(T in, T out, double tolerance) {
- return compare_unary_operation_single_output(op, in, out, tolerance);
+ template <typename T> bool match(T in, T out) {
+ return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
+ rounding);
}
- template <typename T>
- static bool match(T in, const BinaryOutput<T> &out, double tolerance) {
- return compare_unary_operation_two_outputs(op, in, out, tolerance);
+ template <typename T> bool match(T in, const BinaryOutput<T> &out) {
+ return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
+ rounding);
}
- template <typename T>
- static bool match(const BinaryInput<T> &in, T out, double tolerance) {
- return compare_binary_operation_one_output(op, in, out, tolerance);
+ template <typename T> bool match(const BinaryInput<T> &in, T out) {
+ return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
+ rounding);
}
template <typename T>
- static bool match(BinaryInput<T> in, const BinaryOutput<T> &out,
- double tolerance) {
- return compare_binary_operation_two_outputs(op, in, out, tolerance);
+ bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
+ return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
+ rounding);
}
- template <typename T>
- static bool match(const TernaryInput<T> &in, T out, double tolerance) {
- return compare_ternary_operation_one_output(op, in, out, tolerance);
+ template <typename T> bool match(const TernaryInput<T> &in, T out) {
+ return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
+ rounding);
}
template <typename T>
- static void explain_error(T in, T out, testutils::StreamWrapper &OS) {
- explain_unary_operation_single_output_error(op, in, out, OS);
+ void explain_error(T in, T out, testutils::StreamWrapper &OS) {
+ explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
+ rounding, OS);
}
template <typename T>
- static void explain_error(T in, const BinaryOutput<T> &out,
- testutils::StreamWrapper &OS) {
- explain_unary_operation_two_outputs_error(op, in, out, OS);
+ void explain_error(T in, const BinaryOutput<T> &out,
+ testutils::StreamWrapper &OS) {
+ explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
+ rounding, OS);
}
template <typename T>
- static void explain_error(const BinaryInput<T> &in,
- const BinaryOutput<T> &out,
- testutils::StreamWrapper &OS) {
- explain_binary_operation_two_outputs_error(op, in, out, OS);
+ void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out,
+ testutils::StreamWrapper &OS) {
+ explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
+ rounding, OS);
}
template <typename T>
- static void explain_error(const BinaryInput<T> &in, T out,
- testutils::StreamWrapper &OS) {
- explain_binary_operation_one_output_error(op, in, out, OS);
+ void explain_error(const BinaryInput<T> &in, T out,
+ testutils::StreamWrapper &OS) {
+ explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
+ rounding, OS);
}
template <typename T>
- static void explain_error(const TernaryInput<T> &in, T out,
- testutils::StreamWrapper &OS) {
- explain_ternary_operation_one_output_error(op, in, out, OS);
+ void explain_error(const TernaryInput<T> &in, T out,
+ testutils::StreamWrapper &OS) {
+ explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
+ rounding, OS);
}
};
@@ -264,12 +288,12 @@ template <Operation op, typename InputType, typename OutputType>
__attribute__((no_sanitize("address")))
cpp::EnableIfType<is_valid_operation<op, InputType, OutputType>(),
internal::MPFRMatcher<op, InputType, OutputType>>
-get_mpfr_matcher(InputType input, OutputType output_unused, double t) {
- return internal::MPFRMatcher<op, InputType, OutputType>(input, t);
+get_mpfr_matcher(InputType input, OutputType output_unused,
+ double ulp_tolerance, RoundingMode rounding) {
+ return internal::MPFRMatcher<op, InputType, OutputType>(input, ulp_tolerance,
+ rounding);
}
-enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest };
-
template <typename T> T round(T x, RoundingMode mode);
template <typename T> bool round_to_long(T x, long &result);
@@ -279,12 +303,42 @@ template <typename T> bool round_to_long(T x, RoundingMode mode, long &result);
} // namespace testing
} // namespace __llvm_libc
-#define EXPECT_MPFR_MATCH(op, input, match_value, tolerance) \
+// GET_MPFR_DUMMY_ARG is going to be added to the end of GET_MPFR_MACRO as a
+// simple way to avoid the compiler warning `gnu-zero-variadic-macro-arguments`.
+#define GET_MPFR_DUMMY_ARG(...) 0
+
+#define GET_MPFR_MACRO(__1, __2, __3, __4, __5, __NAME, ...) __NAME
+
+#define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \
+ EXPECT_THAT(match_value, \
+ __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \
+ input, match_value, ulp_tolerance, \
+ __llvm_libc::testing::mpfr::RoundingMode::Nearest))
+
+#define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \
+ rounding) \
EXPECT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \
- input, match_value, tolerance))
+ input, match_value, ulp_tolerance, rounding))
+
+#define EXPECT_MPFR_MATCH(...) \
+ GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING, \
+ EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG) \
+ (__VA_ARGS__)
-#define ASSERT_MPFR_MATCH(op, input, match_value, tolerance) \
+#define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \
+ ASSERT_THAT(match_value, \
+ __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \
+ input, match_value, ulp_tolerance, \
+ __llvm_libc::testing::mpfr::RoundingMode::Nearest))
+
+#define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \
+ rounding) \
ASSERT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \
- input, match_value, tolerance))
+ input, match_value, ulp_tolerance, rounding))
+
+#define ASSERT_MPFR_MATCH(...) \
+ GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING, \
+ ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG) \
+ (__VA_ARGS__)
#endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H
More information about the libc-commits
mailing list