[libc-commits] [libc] 3f4674a - [libc] Extend MPFRMatcher to handle multiple-input-multiple-output functions.

Siva Chandra Reddy via libc-commits libc-commits at lists.llvm.org
Tue Aug 25 21:43:10 PDT 2020


Author: Siva Chandra Reddy
Date: 2020-08-25T21:42:49-07:00
New Revision: 3f4674a5577dcc63a846d33f61e9bd95e388223d

URL: https://github.com/llvm/llvm-project/commit/3f4674a5577dcc63a846d33f61e9bd95e388223d
DIFF: https://github.com/llvm/llvm-project/commit/3f4674a5577dcc63a846d33f61e9bd95e388223d.diff

LOG: [libc] Extend MPFRMatcher to handle multiple-input-multiple-output functions.

Tests for frexp[f|l] now use the new capability. Not all input-output
combinations have been addressed by this change. Support for other combinations
can be added in the future as needed.

Reviewed By: lntue

Differential Revision: https://reviews.llvm.org/D86506

Added: 
    

Modified: 
    libc/test/src/math/CMakeLists.txt
    libc/test/src/math/frexp_test.cpp
    libc/test/src/math/frexpf_test.cpp
    libc/test/src/math/frexpl_test.cpp
    libc/utils/MPFRWrapper/MPFRUtils.cpp
    libc/utils/MPFRWrapper/MPFRUtils.h

Removed: 
    


################################################################################
diff  --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index e73de5403564..2fe766a2ffc6 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -333,6 +333,7 @@ add_fp_unittest(
 
 add_fp_unittest(
   frexp_test
+  NEED_MPFR
   SUITE
     libc_math_unittests
   SRCS
@@ -345,6 +346,7 @@ add_fp_unittest(
 
 add_fp_unittest(
   frexpf_test
+  NEED_MPFR
   SUITE
     libc_math_unittests
   SRCS
@@ -357,6 +359,7 @@ add_fp_unittest(
 
 add_fp_unittest(
   frexpl_test
+  NEED_MPFR
   SUITE
     libc_math_unittests
   SRCS

diff  --git a/libc/test/src/math/frexp_test.cpp b/libc/test/src/math/frexp_test.cpp
index f828d515a688..360bbf237560 100644
--- a/libc/test/src/math/frexp_test.cpp
+++ b/libc/test/src/math/frexp_test.cpp
@@ -11,13 +11,18 @@
 #include "utils/FPUtil/BasicOperations.h"
 #include "utils/FPUtil/BitPatterns.h"
 #include "utils/FPUtil/ClassificationFunctions.h"
+#include "utils/FPUtil/FPBits.h"
 #include "utils/FPUtil/FloatOperations.h"
 #include "utils/FPUtil/FloatProperties.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
 #include "utils/UnitTest/Test.h"
 
+using FPBits = __llvm_libc::fputil::FPBits<double>;
 using __llvm_libc::fputil::valueAsBits;
 using __llvm_libc::fputil::valueFromBits;
 
+namespace mpfr = __llvm_libc::testing::mpfr;
+
 using BitPatterns = __llvm_libc::fputil::BitPatterns<double>;
 using Properties = __llvm_libc::fputil::FloatProperties<double>;
 
@@ -127,17 +132,19 @@ TEST(FrexpTest, SomeIntegers) {
 }
 
 TEST(FrexpTest, InDoubleRange) {
-  using BitsType = Properties::BitsType;
-  constexpr BitsType count = 1000000;
-  constexpr BitsType step = UINT64_MAX / count;
-  for (BitsType i = 0, v = 0; i <= count; ++i, v += step) {
-    double x = valueFromBits(v);
+  using UIntType = FPBits::UIntType;
+  constexpr UIntType count = 1000001;
+  constexpr UIntType step = UIntType(-1) / count;
+  for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+    double x = FPBits(v);
     if (isnan(x) || isinf(x) || x == 0.0)
       continue;
-    int exponent;
-    double frac = __llvm_libc::frexp(x, &exponent);
 
-    ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0);
-    ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5);
+    mpfr::BinaryOutput<double> result;
+    result.f = __llvm_libc::frexp(x, &result.i);
+
+    ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0);
+    ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5);
+    ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0);
   }
 }

diff  --git a/libc/test/src/math/frexpf_test.cpp b/libc/test/src/math/frexpf_test.cpp
index 3b82c68078ee..1bf0c36cf165 100644
--- a/libc/test/src/math/frexpf_test.cpp
+++ b/libc/test/src/math/frexpf_test.cpp
@@ -11,14 +11,18 @@
 #include "utils/FPUtil/BasicOperations.h"
 #include "utils/FPUtil/BitPatterns.h"
 #include "utils/FPUtil/ClassificationFunctions.h"
+#include "utils/FPUtil/FPBits.h"
 #include "utils/FPUtil/FloatOperations.h"
 #include "utils/FPUtil/FloatProperties.h"
 #include "utils/MPFRWrapper/MPFRUtils.h"
 #include "utils/UnitTest/Test.h"
 
+using FPBits = __llvm_libc::fputil::FPBits<float>;
 using __llvm_libc::fputil::valueAsBits;
 using __llvm_libc::fputil::valueFromBits;
 
+namespace mpfr = __llvm_libc::testing::mpfr;
+
 using BitPatterns = __llvm_libc::fputil::BitPatterns<float>;
 using Properties = __llvm_libc::fputil::FloatProperties<float>;
 
@@ -109,7 +113,7 @@ TEST(FrexpfTest, PowersOfTwo) {
   EXPECT_EQ(exponent, 7);
 }
 
-TEST(FrexpTest, SomeIntegers) {
+TEST(FrexpfTest, SomeIntegers) {
   int exponent;
 
   EXPECT_EQ(valueAsBits(0.75f),
@@ -135,17 +139,19 @@ TEST(FrexpTest, SomeIntegers) {
 }
 
 TEST(FrexpfTest, InFloatRange) {
-  using BitsType = Properties::BitsType;
-  constexpr BitsType count = 1000000;
-  constexpr BitsType step = UINT32_MAX / count;
-  for (BitsType i = 0, v = 0; i <= count; ++i, v += step) {
-    float x = valueFromBits(v);
+  using UIntType = FPBits::UIntType;
+  constexpr UIntType count = 1000001;
+  constexpr UIntType step = UIntType(-1) / count;
+  for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
+    float x = FPBits(v);
     if (isnan(x) || isinf(x) || x == 0.0)
       continue;
-    int exponent;
-    float frac = __llvm_libc::frexpf(x, &exponent);
 
-    ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0f);
-    ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5f);
+    mpfr::BinaryOutput<float> result;
+    result.f = __llvm_libc::frexpf(x, &result.i);
+
+    ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0);
+    ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5);
+    ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0);
   }
 }

diff  --git a/libc/test/src/math/frexpl_test.cpp b/libc/test/src/math/frexpl_test.cpp
index ace445f0a2de..9846bb84ae27 100644
--- a/libc/test/src/math/frexpl_test.cpp
+++ b/libc/test/src/math/frexpl_test.cpp
@@ -10,10 +10,13 @@
 #include "src/math/frexpl.h"
 #include "utils/FPUtil/BasicOperations.h"
 #include "utils/FPUtil/FPBits.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
 #include "utils/UnitTest/Test.h"
 
 using FPBits = __llvm_libc::fputil::FPBits<long double>;
 
+namespace mpfr = __llvm_libc::testing::mpfr;
+
 TEST(FrexplTest, SpecialNumbers) {
   int exponent;
 
@@ -94,10 +97,11 @@ TEST(FrexplTest, LongDoubleRange) {
     if (isnan(x) || isinf(x) || x == 0.0l)
       continue;
 
-    int exponent;
-    long double frac = __llvm_libc::frexpl(x, &exponent);
+    mpfr::BinaryOutput<long double> result;
+    result.f = __llvm_libc::frexpl(x, &result.i);
 
-    ASSERT_TRUE(__llvm_libc::fputil::abs(frac) < 1.0l);
-    ASSERT_TRUE(__llvm_libc::fputil::abs(frac) >= 0.5l);
+    ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) < 1.0);
+    ASSERT_TRUE(__llvm_libc::fputil::abs(result.f) >= 0.5);
+    ASSERT_MPFR_MATCH(mpfr::Operation::Frexp, x, result, 0.0);
   }
 }

diff  --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index a3abfce08bf3..86882d05cc39 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -14,6 +14,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 
+#include <memory>
 #include <mpfr.h>
 #include <stdint.h>
 #include <string>
@@ -65,50 +66,90 @@ class MPFRNumber {
     mpfr_set_sj(value, x, MPFR_RNDN);
   }
 
-  template <typename XType,
-            cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0>
-  MPFRNumber(Operation op, XType rawValue) {
-    mpfr_init2(value, mpfrPrecision);
-    MPFRNumber mpfrInput(rawValue);
-    switch (op) {
-    case Operation::Abs:
-      mpfr_abs(value, mpfrInput.value, MPFR_RNDN);
-      break;
-    case Operation::Ceil:
-      mpfr_ceil(value, mpfrInput.value);
-      break;
-    case Operation::Cos:
-      mpfr_cos(value, mpfrInput.value, MPFR_RNDN);
-      break;
-    case Operation::Exp:
-      mpfr_exp(value, mpfrInput.value, MPFR_RNDN);
-      break;
-    case Operation::Exp2:
-      mpfr_exp2(value, mpfrInput.value, MPFR_RNDN);
-      break;
-    case Operation::Floor:
-      mpfr_floor(value, mpfrInput.value);
-      break;
-    case Operation::Round:
-      mpfr_round(value, mpfrInput.value);
-      break;
-    case Operation::Sin:
-      mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
-      break;
-    case Operation::Sqrt:
-      mpfr_sqrt(value, mpfrInput.value, MPFR_RNDN);
-      break;
-    case Operation::Trunc:
-      mpfr_trunc(value, mpfrInput.value);
-      break;
-    }
-  }
-
   MPFRNumber(const MPFRNumber &other) {
     mpfr_set(value, other.value, MPFR_RNDN);
   }
 
-  ~MPFRNumber() { mpfr_clear(value); }
+  MPFRNumber &operator=(const MPFRNumber &rhs) {
+    mpfr_set(value, rhs.value, MPFR_RNDN);
+    return *this;
+  }
+
+  MPFRNumber abs() const {
+    MPFRNumber result;
+    mpfr_abs(result.value, value, MPFR_RNDN);
+    return result;
+  }
+
+  MPFRNumber ceil() const {
+    MPFRNumber result;
+    mpfr_ceil(result.value, value);
+    return result;
+  }
+
+  MPFRNumber cos() const {
+    MPFRNumber result;
+    mpfr_cos(result.value, value, MPFR_RNDN);
+    return result;
+  }
+
+  MPFRNumber exp() const {
+    MPFRNumber result;
+    mpfr_exp(result.value, value, MPFR_RNDN);
+    return result;
+  }
+
+  MPFRNumber exp2() const {
+    MPFRNumber result;
+    mpfr_exp2(result.value, value, MPFR_RNDN);
+    return result;
+  }
+
+  MPFRNumber floor() const {
+    MPFRNumber result;
+    mpfr_floor(result.value, value);
+    return result;
+  }
+
+  MPFRNumber frexp(int &exp) {
+    MPFRNumber result;
+    mpfr_exp_t resultExp;
+    mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN);
+    exp = resultExp;
+    return result;
+  }
+
+  MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) {
+    MPFRNumber remainder;
+    long q;
+    mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN);
+    quotient = q;
+    return remainder;
+  }
+
+  MPFRNumber round() const {
+    MPFRNumber result;
+    mpfr_round(result.value, value);
+    return result;
+  }
+
+  MPFRNumber sin() const {
+    MPFRNumber result;
+    mpfr_sin(result.value, value, MPFR_RNDN);
+    return result;
+  }
+
+  MPFRNumber sqrt() const {
+    MPFRNumber result;
+    mpfr_sqrt(result.value, value, MPFR_RNDN);
+    return result;
+  }
+
+  MPFRNumber trunc() const {
+    MPFRNumber result;
+    mpfr_trunc(result.value, value);
+    return result;
+  }
 
   std::string str() const {
     // 200 bytes should be more than sufficient to hold a 100-digit number
@@ -179,10 +220,65 @@ class MPFRNumber {
 
 namespace internal {
 
+template <typename InputType>
+cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
+unaryOperation(Operation op, InputType input) {
+  MPFRNumber mpfrInput(input);
+  switch (op) {
+  case Operation::Abs:
+    return mpfrInput.abs();
+  case Operation::Ceil:
+    return mpfrInput.ceil();
+  case Operation::Cos:
+    return mpfrInput.cos();
+  case Operation::Exp:
+    return mpfrInput.exp();
+  case Operation::Exp2:
+    return mpfrInput.exp2();
+  case Operation::Floor:
+    return mpfrInput.floor();
+  case Operation::Round:
+    return mpfrInput.round();
+  case Operation::Sin:
+    return mpfrInput.sin();
+  case Operation::Sqrt:
+    return mpfrInput.sqrt();
+  case Operation::Trunc:
+    return mpfrInput.trunc();
+  default:
+    __builtin_unreachable();
+  }
+}
+
+template <typename InputType>
+cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
+unaryOperationTwoOutputs(Operation op, InputType input, int &output) {
+  MPFRNumber mpfrInput(input);
+  switch (op) {
+  case Operation::Frexp:
+    return mpfrInput.frexp(output);
+  default:
+    __builtin_unreachable();
+  }
+}
+
+template <typename InputType>
+cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
+binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
+  MPFRNumber inputX(x), inputY(y);
+  switch (op) {
+  case Operation::RemQuo:
+    return inputX.remquo(inputY, output);
+  default:
+    __builtin_unreachable();
+  }
+}
+
 template <typename T>
-void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
-  MPFRNumber mpfrResult(operation, input);
+void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
+                                            testutils::StreamWrapper &OS) {
   MPFRNumber mpfrInput(input);
+  MPFRNumber mpfrResult = unaryOperation(op, input);
   MPFRNumber mpfrMatchValue(matchValue);
   FPBits<T> inputBits(input);
   FPBits<T> matchBits(matchValue);
@@ -201,25 +297,174 @@ void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
      << '\n';
 }
 
-template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
-template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &);
 template void
-MPFRMatcher<long double>::explainError(testutils::StreamWrapper &);
+explainUnaryOperationSingleOutputError<float>(Operation op, float, float,
+                                              testutils::StreamWrapper &);
+template void
+explainUnaryOperationSingleOutputError<double>(Operation op, double, double,
+                                               testutils::StreamWrapper &);
+template void explainUnaryOperationSingleOutputError<long double>(
+    Operation op, long double, long double, testutils::StreamWrapper &);
+
+template <typename T>
+void explainUnaryOperationTwoOutputsError(Operation op, T input,
+                                          const BinaryOutput<T> &libcResult,
+                                          testutils::StreamWrapper &OS) {
+  MPFRNumber mpfrInput(input);
+  FPBits<T> inputBits(input);
+  int mpfrIntResult;
+  MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
+
+  if (mpfrIntResult != libcResult.i) {
+    OS << "MPFR integral result: " << mpfrIntResult << '\n'
+       << "Libc integral result: " << libcResult.i << '\n';
+  } else {
+    OS << "Integral result from libc matches integral result from MPFR.\n";
+  }
+
+  MPFRNumber mpfrMatchValue(libcResult.f);
+  OS << "Libc floating point result is not within tolerance value of the MPFR "
+     << "result.\n\n";
+
+  OS << "            Input decimal: " << mpfrInput.str() << "\n\n";
+
+  OS << "Libc floating point value: " << mpfrMatchValue.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue(
+      " Libc floating point bits: ", libcResult.f, OS);
+  OS << "\n\n";
+
+  OS << "              MPFR result: " << mpfrResult.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue(
+      "             MPFR rounded: ", mpfrResult.as<T>(), OS);
+  OS << '\n'
+     << "                ULP error: "
+     << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
+}
+
+template void explainUnaryOperationTwoOutputsError<float>(
+    Operation, float, const BinaryOutput<float> &, testutils::StreamWrapper &);
+template void
+explainUnaryOperationTwoOutputsError<double>(Operation, double,
+                                             const BinaryOutput<double> &,
+                                             testutils::StreamWrapper &);
+template void explainUnaryOperationTwoOutputsError<long double>(
+    Operation, long double, const BinaryOutput<long double> &,
+    testutils::StreamWrapper &);
 
 template <typename T>
-bool compare(Operation op, T input, T libcResult, double ulpError) {
+void explainBinaryOperationTwoOutputsError(Operation op,
+                                           const BinaryInput<T> &input,
+                                           const BinaryOutput<T> &libcResult,
+                                           testutils::StreamWrapper &OS) {
+  MPFRNumber mpfrX(input.x);
+  MPFRNumber mpfrY(input.y);
+  FPBits<T> xbits(input.x);
+  FPBits<T> ybits(input.y);
+  int mpfrIntResult;
+  MPFRNumber mpfrResult =
+      binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
+  MPFRNumber mpfrMatchValue(libcResult.f);
+
+  OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n'
+     << "MPFR integral result: " << mpfrIntResult << '\n'
+     << "Libc integral result: " << libcResult.i << '\n'
+     << "Libc floating point result: " << mpfrMatchValue.str() << '\n'
+     << "               MPFR result: " << mpfrResult.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue(
+      "Libc floating point result bits: ", libcResult.f, OS);
+  __llvm_libc::fputil::testing::describeValue(
+      "              MPFR rounded bits: ", mpfrResult.as<T>(), OS);
+  OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
+}
+
+template void explainBinaryOperationTwoOutputsError<float>(
+    Operation, const BinaryInput<float> &, const BinaryOutput<float> &,
+    testutils::StreamWrapper &);
+template void explainBinaryOperationTwoOutputsError<double>(
+    Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
+    testutils::StreamWrapper &);
+template void explainBinaryOperationTwoOutputsError<long double>(
+    Operation, const BinaryInput<long double> &,
+    const BinaryOutput<long double> &, testutils::StreamWrapper &);
+
+template <typename T>
+bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
+                                       double ulpError) {
   // If the ulp error is exactly 0.5 (i.e a tie), we would check that the result
   // is rounded to the nearest even.
-  MPFRNumber mpfrResult(op, input);
+  MPFRNumber mpfrResult = unaryOperation(op, input);
   double ulp = mpfrResult.ulp(libcResult);
   bool bitsAreEven = ((FPBits<T>(libcResult).bitsAsUInt() & 1) == 0);
   return (ulp < ulpError) ||
          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
 }
 
-template bool compare<float>(Operation, float, float, double);
-template bool compare<double>(Operation, double, double, double);
-template bool compare<long double>(Operation, long double, long double, double);
+template bool compareUnaryOperationSingleOutput<float>(Operation, float, float,
+                                                       double);
+template bool compareUnaryOperationSingleOutput<double>(Operation, double,
+                                                        double, double);
+template bool compareUnaryOperationSingleOutput<long double>(Operation,
+                                                             long double,
+                                                             long double,
+                                                             double);
+
+template <typename T>
+bool compareUnaryOperationTwoOutputs(Operation op, T input,
+                                     const BinaryOutput<T> &libcResult,
+                                     double ulpError) {
+  int mpfrIntResult;
+  MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
+  double ulp = mpfrResult.ulp(libcResult.f);
+
+  if (mpfrIntResult != libcResult.i)
+    return false;
+
+  bool bitsAreEven = ((FPBits<T>(libcResult.f).bitsAsUInt() & 1) == 0);
+  return (ulp < ulpError) ||
+         ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
+}
+
+template bool
+compareUnaryOperationTwoOutputs<float>(Operation, float,
+                                       const BinaryOutput<float> &, double);
+template bool
+compareUnaryOperationTwoOutputs<double>(Operation, double,
+                                        const BinaryOutput<double> &, double);
+template bool compareUnaryOperationTwoOutputs<long double>(
+    Operation, long double, const BinaryOutput<long double> &, double);
+
+template <typename T>
+bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
+                                      const BinaryOutput<T> &libcResult,
+                                      double ulpError) {
+  int mpfrIntResult;
+  MPFRNumber mpfrResult =
+      binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
+  double ulp = mpfrResult.ulp(libcResult.f);
+
+  if (mpfrIntResult != libcResult.i) {
+    if (op == Operation::RemQuo) {
+      if ((0x7 & mpfrIntResult) != libcResult.i)
+        return false;
+    } else {
+      return false;
+    }
+  }
+
+  bool bitsAreEven = ((FPBits<T>(libcResult.f).bitsAsUInt() & 1) == 0);
+  return (ulp < ulpError) ||
+         ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
+}
+
+template bool
+compareBinaryOperationTwoOutputs<float>(Operation, const BinaryInput<float> &,
+                                        const BinaryOutput<float> &, double);
+template bool
+compareBinaryOperationTwoOutputs<double>(Operation, const BinaryInput<double> &,
+                                         const BinaryOutput<double> &, double);
+template bool compareBinaryOperationTwoOutputs<long double>(
+    Operation, const BinaryInput<long double> &,
+    const BinaryOutput<long double> &, double);
 
 } // namespace internal
 

diff  --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index 3d94079e65d8..b46f09dd5e55 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -19,6 +19,10 @@ namespace testing {
 namespace mpfr {
 
 enum class Operation : int {
+  // Operations which take a single floating point number as input
+  // and produce a single floating point number as output. The input
+  // and output floating point numbers are of the same type.
+  BeginUnaryOperationsSingleOutput,
   Abs,
   Ceil,
   Cos,
@@ -28,45 +32,193 @@ enum class Operation : int {
   Round,
   Sin,
   Sqrt,
-  Trunc
+  Trunc,
+  EndUnaryOperationsSingleOutput,
+
+  // Operations which take a single floating point number as input
+  // but produce two outputs. The first output is a floating point
+  // number of the same type as the input. The second output is of type
+  // 'int'.
+  BeginUnaryOperationsTwoOutputs,
+  Frexp, // Floating point output, the first output, is the fractional part.
+  EndUnaryOperationsTwoOutputs,
+
+  // Operations which take two floating point numbers of the same type as
+  // input and produce a single floating point number of the same type as
+  // output.
+  BeginBinaryOperationsSingleOutput,
+  // TODO: Add operations like hypot.
+  EndBinaryOperationsSingleOutput,
+
+  // Operations which take two floating point numbers of the same type as
+  // input and produce two outputs. The first output is a floating point
+  // number of the same type as the inputs. The second output is an 'int'.
+  BeginBinaryOperationsTwoOutputs,
+  RemQuo, // The first output, the floating point output, is the remainder.
+  EndBinaryOperationsTwoOutputs,
+
+  BeginTernaryOperationsSingleOuput,
+  // TODO: Add operations like fma.
+  EndTernaryOperationsSingleOutput,
+};
+
+template <typename T> struct BinaryInput {
+  static_assert(
+      __llvm_libc::cpp::IsFloatingPointType<T>::Value,
+      "Template parameter of BinaryInput must be a floating point type.");
+
+  using Type = T;
+  T x, y;
+};
+
+template <typename T> struct TernaryInput {
+  static_assert(
+      __llvm_libc::cpp::IsFloatingPointType<T>::Value,
+      "Template parameter of TernaryInput must be a floating point type.");
+
+  using Type = T;
+  T x, y, z;
+};
+
+template <typename T> struct BinaryOutput {
+  T f;
+  int i;
 };
 
 namespace internal {
 
+template <typename T1, typename T2>
+struct AreMatchingBinaryInputAndBinaryOutput {
+  static constexpr bool value = false;
+};
+
 template <typename T>
-bool compare(Operation op, T input, T libcOutput, double t);
+struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
+  static constexpr bool value = cpp::IsFloatingPointType<T>::Value;
+};
 
-template <typename T> class MPFRMatcher : public testing::Matcher<T> {
-  static_assert(__llvm_libc::cpp::IsFloatingPointType<T>::Value,
-                "MPFRMatcher can only be used with floating point values.");
+template <typename T>
+bool compareUnaryOperationSingleOutput(Operation op, T input, T libcOutput,
+                                       double t);
+template <typename T>
+bool compareUnaryOperationTwoOutputs(Operation op, T input,
+                                     const BinaryOutput<T> &libcOutput,
+                                     double t);
+template <typename T>
+bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
+                                      const BinaryOutput<T> &libcOutput,
+                                      double t);
 
-  Operation operation;
-  T input;
-  T matchValue;
+template <typename T>
+void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
+                                            testutils::StreamWrapper &OS);
+template <typename T>
+void explainUnaryOperationTwoOutputsError(Operation op, T input,
+                                          const BinaryOutput<T> &matchValue,
+                                          testutils::StreamWrapper &OS);
+template <typename T>
+void explainBinaryOperationTwoOutputsError(Operation op,
+                                           const BinaryInput<T> &input,
+                                           const BinaryOutput<T> &matchValue,
+                                           testutils::StreamWrapper &OS);
+
+template <Operation op, typename InputType, typename OutputType>
+class MPFRMatcher : public testing::Matcher<OutputType> {
+  InputType input;
+  OutputType matchValue;
   double ulpTolerance;
 
 public:
-  MPFRMatcher(Operation op, T testInput, double ulpTolerance)
-      : operation(op), input(testInput), ulpTolerance(ulpTolerance) {}
+  MPFRMatcher(InputType testInput, double ulpTolerance)
+      : input(testInput), ulpTolerance(ulpTolerance) {}
 
-  bool match(T libcResult) {
+  bool match(OutputType libcResult) {
     matchValue = libcResult;
-    return internal::compare(operation, input, libcResult, ulpTolerance);
+    return match(input, matchValue, ulpTolerance);
   }
 
-  void explainError(testutils::StreamWrapper &OS) override;
+  void explainError(testutils::StreamWrapper &OS) override {
+    explainError(input, matchValue, OS);
+  }
+
+private:
+  template <typename T> static bool match(T in, T out, double tolerance) {
+    return compareUnaryOperationSingleOutput(op, in, out, tolerance);
+  }
+
+  template <typename T>
+  static bool match(T in, const BinaryOutput<T> &out, double tolerance) {
+    return compareUnaryOperationTwoOutputs(op, in, out, tolerance);
+  }
+
+  template <typename T>
+  static bool match(const BinaryInput<T> &in, T out, double tolerance) {
+    // TODO: Implement the comparision function and error reporter.
+  }
+
+  template <typename T>
+  static bool match(BinaryInput<T> in, const BinaryOutput<T> &out,
+                    double tolerance) {
+    return compareBinaryOperationTwoOutputs(op, in, out, tolerance);
+  }
+
+  template <typename T>
+  static bool match(const TernaryInput<T> &in, T out, double tolerance) {
+    // TODO: Implement the comparision function and error reporter.
+  }
+
+  template <typename T>
+  static void explainError(T in, T out, testutils::StreamWrapper &OS) {
+    explainUnaryOperationSingleOutputError(op, in, out, OS);
+  }
+
+  template <typename T>
+  static void explainError(T in, const BinaryOutput<T> &out,
+                           testutils::StreamWrapper &OS) {
+    explainUnaryOperationTwoOutputsError(op, in, out, OS);
+  }
+
+  template <typename T>
+  static void explainError(const BinaryInput<T> &in, const BinaryOutput<T> &out,
+                           testutils::StreamWrapper &OS) {
+    explainBinaryOperationTwoOutputsError(op, in, out, OS);
+  }
 };
 
 } // namespace internal
 
-template <typename T, typename U>
+// Returns true if InputType and OutputType are valid input and output types
+// for the operation op.
+template <Operation op, typename InputType, typename OutputType>
+constexpr bool isValidOperation() {
+  return (Operation::BeginUnaryOperationsSingleOutput < op &&
+          op < Operation::EndUnaryOperationsSingleOutput &&
+          cpp::IsSame<InputType, OutputType>::Value &&
+          cpp::IsFloatingPointType<InputType>::Value) ||
+         (Operation::BeginUnaryOperationsTwoOutputs < op &&
+          op < Operation::EndUnaryOperationsTwoOutputs &&
+          cpp::IsFloatingPointType<InputType>::Value &&
+          cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) ||
+         (Operation::BeginBinaryOperationsSingleOutput < op &&
+          op < Operation::EndBinaryOperationsSingleOutput &&
+          cpp::IsFloatingPointType<OutputType>::Value &&
+          cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) ||
+         (Operation::BeginBinaryOperationsTwoOutputs < op &&
+          op < Operation::EndBinaryOperationsTwoOutputs &&
+          internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
+                                                          OutputType>::value) ||
+         (Operation::BeginTernaryOperationsSingleOuput < op &&
+          op < Operation::EndTernaryOperationsSingleOutput &&
+          cpp::IsFloatingPointType<OutputType>::Value &&
+          cpp::IsSame<InputType, TernaryInput<OutputType>>::Value);
+}
+
+template <Operation op, typename InputType, typename OutputType>
 __attribute__((no_sanitize("address")))
-typename cpp::EnableIfType<cpp::IsSameV<U, double>, internal::MPFRMatcher<T>>
-getMPFRMatcher(Operation op, T input, U t) {
-  static_assert(
-      __llvm_libc::cpp::IsFloatingPointType<T>::Value,
-      "getMPFRMatcher can only be used to match floating point results.");
-  return internal::MPFRMatcher<T>(op, input, t);
+cpp::EnableIfType<isValidOperation<op, InputType, OutputType>(),
+                  internal::MPFRMatcher<op, InputType, OutputType>>
+getMPFRMatcher(InputType input, OutputType outputUnused, double t) {
+  return internal::MPFRMatcher<op, InputType, OutputType>(input, t);
 }
 
 } // namespace mpfr
@@ -74,11 +226,11 @@ getMPFRMatcher(Operation op, T input, U t) {
 } // namespace __llvm_libc
 
 #define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance)                    \
-  EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher(          \
-                              op, input, tolerance))
+  EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>(      \
+                              input, matchValue, tolerance))
 
 #define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance)                    \
-  ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher(          \
-                              op, input, tolerance))
+  ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>(      \
+                              input, matchValue, tolerance))
 
 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H


        


More information about the libc-commits mailing list