[libc-commits] [libc] 4726bec - [libc] Add implementation of fmaf.

Tue Ly via libc-commits libc-commits at lists.llvm.org
Wed Jan 6 14:14:56 PST 2021


Author: Tue Ly
Date: 2021-01-06T17:14:20-05:00
New Revision: 4726bec8f29bd535e2709b491d223d42bd20c120

URL: https://github.com/llvm/llvm-project/commit/4726bec8f29bd535e2709b491d223d42bd20c120
DIFF: https://github.com/llvm/llvm-project/commit/4726bec8f29bd535e2709b491d223d42bd20c120.diff

LOG: [libc] Add implementation of fmaf.

Differential Revision: https://reviews.llvm.org/D94018

Added: 
    libc/src/math/fmaf.cpp
    libc/src/math/fmaf.h
    libc/test/src/math/FmaTest.h
    libc/test/src/math/fmaf_test.cpp

Modified: 
    libc/config/linux/aarch64/entrypoints.txt
    libc/config/linux/x86_64/entrypoints.txt
    libc/spec/stdc.td
    libc/src/math/CMakeLists.txt
    libc/test/src/math/CMakeLists.txt
    libc/utils/FPUtil/FPBits.h
    libc/utils/MPFRWrapper/MPFRUtils.cpp
    libc/utils/MPFRWrapper/MPFRUtils.h

Removed: 
    


################################################################################
diff  --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index b9042625e666..0db8c4b39caa 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -65,6 +65,7 @@ set(TARGET_LIBM_ENTRYPOINTS
     libc.src.math.floor
     libc.src.math.floorf
     libc.src.math.floorl
+    libc.src.math.fmaf
     libc.src.math.fmax
     libc.src.math.fmaxf
     libc.src.math.fmaxl

diff  --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index a34c59646149..a80a8b4f105b 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -106,6 +106,7 @@ set(TARGET_LIBM_ENTRYPOINTS
     libc.src.math.floor
     libc.src.math.floorf
     libc.src.math.floorl
+    libc.src.math.fmaf
     libc.src.math.fmin
     libc.src.math.fminf
     libc.src.math.fminl

diff  --git a/libc/spec/stdc.td b/libc/spec/stdc.td
index 41f6083a2336..e89d16633ae3 100644
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -322,6 +322,8 @@ def StdC : StandardSpec<"stdc"> {
           FunctionSpec<"fmaxf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<FloatType>]>,
           FunctionSpec<"fmaxl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>, ArgSpec<LongDoubleType>]>,
 
+          FunctionSpec<"fmaf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<FloatType>, ArgSpec<FloatType>]>,
+
           FunctionSpec<"frexp", RetValSpec<DoubleType>, [ArgSpec<DoubleType>, ArgSpec<IntPtr>]>,
           FunctionSpec<"frexpf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<IntPtr>]>,
           FunctionSpec<"frexpl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>, ArgSpec<IntPtr>]>,

diff  --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index 15a2e4645edd..34b7dfcd4306 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -978,3 +978,14 @@ add_entrypoint_object(
     -O2
 )
 
+add_entrypoint_object(
+  fmaf
+  SRCS
+    fmaf.cpp
+  HDRS
+    fmaf.h
+  DEPENDS
+    libc.utils.FPUtil.fputil
+  COMPILE_OPTIONS
+    -O2
+)

diff  --git a/libc/src/math/fmaf.cpp b/libc/src/math/fmaf.cpp
new file mode 100644
index 000000000000..1860d887d630
--- /dev/null
+++ b/libc/src/math/fmaf.cpp
@@ -0,0 +1,64 @@
+//===-- Implementation of fmaf function -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/common.h"
+
+#include "utils/FPUtil/FEnv.h"
+#include "utils/FPUtil/FPBits.h"
+
+namespace __llvm_libc {
+
+// Computes x * y + z rounded once to float.
+//
+// The product of two floats is formed exactly in double and added to z in
+// double. If that double addition itself rounded, the last bit of the double
+// sum is nudged so that the final double -> float conversion rounds the same
+// way the exact value would. Assumes the default round-to-nearest,
+// ties-to-even mode (see the comments below).
+float LLVM_LIBC_ENTRYPOINT(fmaf)(float x, float y, float z) {
+  // Product is exact: two 24-bit significands need at most 48 bits, which fit
+  // in double's 53.
+  double prod = static_cast<double>(x) * static_cast<double>(y);
+  double z_d = static_cast<double>(z);
+  double sum = prod + z_d;
+  fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum);
+
+  // Inf, NaN and zero sums are returned as-is (converted to float below).
+  if (!(bit_sum.isInfOrNaN() || bit_sum.isZero())) {
+    // Since the sum is computed in double precision, rounding might happen
+    // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
+    // bit_prod.exponent > bitz.exponent + 40).  In that case, when we round
+    // the sum back to float, double rounding error might occur.
+    // A concrete example of this phenomenon is as follows:
+    //   x = y = 1 + 2^(-12), z = 2^(-53)
+    // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
+    // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
+    // On the other hand, with the default rounding mode,
+    //   double(x*y + z) = 1 + 2^(-11) + 2^(-24)
+    // and casting again to float gives us:
+    //   float(double(x*y + z)) = 1 + 2^(-11).
+    //
+    // In order to correct this possible double rounding error, first we use
+    // Dekker's Fast2Sum algorithm to find t such that sum - t = prod + z
+    // exactly, assuming the (default) rounding mode is round-to-the-nearest,
+    // tie-to-even.  Fast2Sum requires the operand with the larger exponent to
+    // be subtracted first, hence the branch on the exponents below.
+    // Moreover, |t| <= ulp(sum) / 2,
+    // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding
+    // occurs when computing the sum, we just need to use t to adjust (any) last
+    // bit of sum, so that the sticky bits used when rounding sum to float are
+    // correct (when it matters).
+    fputil::FPBits<double> t(
+        (bit_prod.exponent >= bitz.exponent)
+            ? ((static_cast<double>(bit_sum) - bit_prod) - bitz)
+            : ((static_cast<double>(bit_sum) - bitz) - bit_prod));
+
+    // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are
+    // zero. Those 28 bits (mask 0xfff'ffff) lie below the bit that decides
+    // rounding to float; if any of them is already set, no nudge is needed.
+    if (!t.isZero() && ((bit_sum.mantissa & 0xfff'ffffULL) == 0)) {
+      if (bit_sum.sign != t.sign) {
+        // t = sum - exact has the opposite sign: |sum| < |exact|, so nudge the
+        // magnitude up (cannot carry out: the low 28 bits are zero).
+        ++bit_sum.mantissa;
+      } else if (bit_sum.mantissa) {
+        // Same sign: |sum| > |exact|, so nudge the magnitude down. Skipped
+        // when the mantissa is 0 so the decrement never borrows from the
+        // exponent field.
+        --bit_sum.mantissa;
+      }
+    }
+  }
+
+  return static_cast<float>(static_cast<double>(bit_sum));
+}
+
+} // namespace __llvm_libc

diff  --git a/libc/src/math/fmaf.h b/libc/src/math/fmaf.h
new file mode 100644
index 000000000000..48fbb65d6650
--- /dev/null
+++ b/libc/src/math/fmaf.h
@@ -0,0 +1,18 @@
+//===-- Implementation header for fmaf --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_FMAF_H
+#define LLVM_LIBC_SRC_MATH_FMAF_H
+
+namespace __llvm_libc {
+
+// Returns x * y + z rounded once to float (fused multiply-add).
+float fmaf(float x, float y, float z);
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_MATH_FMAF_H

diff  --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index 3f4d8c76d7c0..48f8902e3f59 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -1049,3 +1049,16 @@ add_fp_unittest(
     libc.src.math.nextafterl
     libc.utils.FPUtil.fputil
 )
+
+add_fp_unittest(
+  fmaf_test
+  NEED_MPFR
+  SUITE
+    libc_math_unittests
+  SRCS
+    fmaf_test.cpp
+  DEPENDS
+    libc.include.math
+    libc.src.math.fmaf
+    libc.utils.FPUtil.fputil
+)

diff  --git a/libc/test/src/math/FmaTest.h b/libc/test/src/math/FmaTest.h
new file mode 100644
index 000000000000..c39c4ad0f1da
--- /dev/null
+++ b/libc/test/src/math/FmaTest.h
@@ -0,0 +1,94 @@
+//===-- Utility class to test different flavors of fma --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
+#define LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
+
+#include "utils/FPUtil/FPBits.h"
+#include "utils/FPUtil/TestHelpers.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+#include "utils/UnitTest/Test.h"
+
+#include <random>
+
+namespace mpfr = __llvm_libc::testing::mpfr;
+
+// Shared test suite for the fma family, parameterized on the floating point
+// type T. Each test* method receives the implementation under test as `func`.
+template <typename T>
+class FmaTestTemplate : public __llvm_libc::testing::Test {
+private:
+  using Func = T (*)(T, T, T);
+  using FPBits = __llvm_libc::fputil::FPBits<T>;
+  using UIntType = typename FPBits::UIntType;
+  const T nan = __llvm_libc::fputil::FPBits<T>::buildNaN(1);
+  const T inf = __llvm_libc::fputil::FPBits<T>::inf();
+  const T negInf = __llvm_libc::fputil::FPBits<T>::negInf();
+  const T zero = __llvm_libc::fputil::FPBits<T>::zero();
+  const T negZero = __llvm_libc::fputil::FPBits<T>::negZero();
+
+  // Returns a pseudo-random bit pattern assembled from 16-bit std::rand()
+  // chunks.
+  // NOTE(review): each step shifts by only 2 bits while adding a 16-bit chunk,
+  // so the chunks overlap and the result stays small (below 2^19 for a 32-bit
+  // UIntType, below 2^23 for a 64-bit one), i.e. always a non-negative
+  // subnormal (or zero) encoding. A full-width pattern presumably needs
+  // `bits << 16`, but would then also produce NaN/Inf and results in the
+  // subnormal range that the MPFR comparison would have to handle -- confirm
+  // before changing.
+  UIntType getRandomBitPattern() {
+    UIntType bits{0};
+    for (size_t i = 0; i < sizeof(UIntType) / 2; ++i) {
+      bits = (bits << 2) + static_cast<uint16_t>(std::rand());
+    }
+    return bits;
+  }
+
+public:
+  // Signed zeros, infinities, NaNs and a few hand-picked rounding cases.
+  void testSpecialNumbers(Func func) {
+    EXPECT_FP_EQ(func(zero, zero, zero), zero);
+    EXPECT_FP_EQ(func(zero, negZero, negZero), negZero);
+    EXPECT_FP_EQ(func(inf, inf, zero), inf);
+    EXPECT_FP_EQ(func(negInf, inf, negInf), negInf);
+    EXPECT_FP_EQ(func(inf, zero, zero), nan);
+    EXPECT_FP_EQ(func(inf, negInf, inf), nan);
+    EXPECT_FP_EQ(func(nan, zero, inf), nan);
+    EXPECT_FP_EQ(func(inf, negInf, nan), nan);
+
+    // Test a tie in the subnormal range: 0.5 * minSubnormal + minSubnormal is
+    // 1.5 ulps, which rounds to even, i.e. up to 2 * minSubnormal.
+    EXPECT_FP_EQ(func(T(0.5), FPBits(FPBits::minSubnormal),
+                      FPBits(FPBits::minSubnormal)),
+                 FPBits(UIntType(2)));
+    // Test a product that underflows on its own: T(FPBits::minNormal << 1) is
+    // the numeric value of that integer (2^24 for float), so x * v is just
+    // over half an ulp of minNormal and x * v + minNormal must round to v.
+    FPBits v(FPBits::minNormal + UIntType(1));
+    EXPECT_FP_EQ(
+        func(T(1) / T(FPBits::minNormal << 1), v, FPBits(FPBits::minNormal)),
+        v);
+    // Test that the intermediate product 1.75 * maxNormal, which would
+    // overflow T on its own, does not overflow the fused result.
+    FPBits z(FPBits::maxNormal);
+    EXPECT_FP_EQ(func(T(1.75), z, -z), T(0.75) * z);
+  }
+
+  // Sweeps y upward and z downward through the subnormal encodings in about
+  // `count` steps, with x from getRandomBitPattern(), and checks each result
+  // against MPFR within 0.5 ulp.
+  void testSubnormalRange(Func func) {
+    constexpr UIntType count = 1000001;
+    constexpr UIntType step =
+        (FPBits::maxSubnormal - FPBits::minSubnormal) / count;
+    for (UIntType v = FPBits::minSubnormal, w = FPBits::maxSubnormal;
+         v <= FPBits::maxSubnormal && w >= FPBits::minSubnormal;
+         v += step, w -= step) {
+      T x = FPBits(getRandomBitPattern()), y = FPBits(v), z = FPBits(w);
+      T result = func(x, y, z);
+      mpfr::TernaryInput<T> input{x, y, z};
+      ASSERT_MPFR_MATCH(mpfr::Operation::Fma, input, result, 0.5);
+    }
+  }
+
+  // Sweeps x upward and y downward through the normal encodings in about
+  // `count` steps, with z from getRandomBitPattern(), and checks each result
+  // against MPFR within 0.5 ulp.
+  void testNormalRange(Func func) {
+    constexpr UIntType count = 1000001;
+    constexpr UIntType step = (FPBits::maxNormal - FPBits::minNormal) / count;
+    for (UIntType v = FPBits::minNormal, w = FPBits::maxNormal;
+         v <= FPBits::maxNormal && w >= FPBits::minNormal;
+         v += step, w -= step) {
+      T x = FPBits(v), y = FPBits(w), z = FPBits(getRandomBitPattern());
+      T result = func(x, y, z);
+      mpfr::TernaryInput<T> input{x, y, z};
+      ASSERT_MPFR_MATCH(mpfr::Operation::Fma, input, result, 0.5);
+    }
+  }
+};
+
+#endif // LLVM_LIBC_TEST_SRC_MATH_FMATEST_H

diff  --git a/libc/test/src/math/fmaf_test.cpp b/libc/test/src/math/fmaf_test.cpp
new file mode 100644
index 000000000000..96d3ca7cf113
--- /dev/null
+++ b/libc/test/src/math/fmaf_test.cpp
@@ -0,0 +1,19 @@
+//===-- Unittests for fmaf ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "FmaTest.h"
+
+#include "src/math/fmaf.h"
+
+// Instantiate the shared fma test suite for float and run each of its checks
+// against __llvm_libc::fmaf.
+using FmaTest = FmaTestTemplate<float>;
+
+TEST_F(FmaTest, SpecialNumbers) { testSpecialNumbers(&__llvm_libc::fmaf); }
+
+TEST_F(FmaTest, SubnormalRange) { testSubnormalRange(&__llvm_libc::fmaf); }
+
+TEST_F(FmaTest, NormalRange) { testNormalRange(&__llvm_libc::fmaf); }

diff  --git a/libc/utils/FPUtil/FPBits.h b/libc/utils/FPUtil/FPBits.h
index 89bdd92669b8..02c0dc880579 100644
--- a/libc/utils/FPUtil/FPBits.h
+++ b/libc/utils/FPUtil/FPBits.h
@@ -84,7 +84,10 @@ template <typename T> struct __attribute__((packed)) FPBits {
-  // We don't want accidental type promotions/conversions so we require exact
-  // type match.
+  // We don't want accidental type promotions/conversions, so we require either
+  // an exact type match or an integral type of the same size as UIntType; in
+  // both cases the bytes of `x` are reinterpreted as the encoding of T.
   template <typename XType,
-            cpp::EnableIfType<cpp::IsSame<T, XType>::Value, int> = 0>
+            cpp::EnableIfType<cpp::IsSame<T, XType>::Value ||
+                                  (cpp::IsIntegral<XType>::Value &&
+                                   (sizeof(XType) == sizeof(UIntType))),
+                              int> = 0>
   explicit FPBits(XType x) {
     *this = *reinterpret_cast<FPBits<T> *>(&x);
   }
@@ -106,13 +109,6 @@ template <typename T> struct __attribute__((packed)) FPBits {
   // the potential software implementations of UIntType will not slow real
   // code.
 
-  template <typename XType,
-            cpp::EnableIfType<cpp::IsSame<UIntType, XType>::Value, int> = 0>
-  explicit FPBits<long double>(XType x) {
-    // The last 4 bytes of v are ignored in case of i386.
-    *this = *reinterpret_cast<FPBits<T> *>(&x);
-  }
-
   UIntType bitsAsUInt() const {
     return *reinterpret_cast<const UIntType *>(this);
   }

diff  --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 3c66ccfe5b38..7cc80a281c6e 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -35,48 +35,69 @@ namespace __llvm_libc {
 namespace testing {
 namespace mpfr {
 
+// Number of significand bits (including the leading bit) of each floating
+// point type. Used as the MPFR working precision when a reference result must
+// be rounded once, directly to that precision.
+template <typename T> struct Precision;
+
+template <> struct Precision<float> {
+  static constexpr unsigned int value = 24;
+};
+
+template <> struct Precision<double> {
+  static constexpr unsigned int value = 53;
+};
+
+// Take the significand width from the compiler instead of guessing it from the
+// architecture: 64 for the x87 80-bit format (x86 / x86_64), 113 for IEEE
+// binary128 (e.g. AArch64 Linux), 53 where long double is just double.
+template <> struct Precision<long double> {
+  static constexpr unsigned int value = __LDBL_MANT_DIG__;
+};
+
 class MPFRNumber {
   // A precision value which allows sufficiently large additional
   // precision even compared to quad-precision floating point values.
-  static constexpr unsigned int mpfrPrecision = 128;
+  unsigned int mpfrPrecision;
 
   mpfr_t value;
 
 public:
-  MPFRNumber() { mpfr_init2(value, mpfrPrecision); }
+  MPFRNumber() : mpfrPrecision(128) { mpfr_init2(value, mpfrPrecision); }
 
   // We use explicit EnableIf specializations to disallow implicit
   // conversions. Implicit conversions can potentially lead to loss of
   // precision.
   template <typename XType,
             cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
-  explicit MPFRNumber(XType x) {
+  explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
     mpfr_init2(value, mpfrPrecision);
     mpfr_set_flt(value, x, MPFR_RNDN);
   }
 
   template <typename XType,
             cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
-  explicit MPFRNumber(XType x) {
+  explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
     mpfr_init2(value, mpfrPrecision);
     mpfr_set_d(value, x, MPFR_RNDN);
   }
 
   template <typename XType,
             cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
-  explicit MPFRNumber(XType x) {
+  explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
     mpfr_init2(value, mpfrPrecision);
     mpfr_set_ld(value, x, MPFR_RNDN);
   }
 
   template <typename XType,
             cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
-  explicit MPFRNumber(XType x) {
+  explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
     mpfr_init2(value, mpfrPrecision);
     mpfr_set_sj(value, x, MPFR_RNDN);
   }
 
-  MPFRNumber(const MPFRNumber &other) {
+  MPFRNumber(const MPFRNumber &other) : mpfrPrecision(other.mpfrPrecision) {
+    mpfr_init2(value, mpfrPrecision);
     mpfr_set(value, other.value, MPFR_RNDN);
   }
 
@@ -85,6 +106,7 @@ class MPFRNumber {
   }
 
   MPFRNumber &operator=(const MPFRNumber &rhs) {
+    // Guard self-assignment: mpfr_set_prec below discards the current value.
+    if (this == &rhs)
+      return *this;
+    mpfrPrecision = rhs.mpfrPrecision;
+    // mpfr_set rounds into the destination's *current* precision, so resize
+    // `value` first to keep it in sync with mpfrPrecision and copy rhs
+    // without loss.
+    mpfr_set_prec(value, mpfrPrecision);
     mpfr_set(value, rhs.value, MPFR_RNDN);
     return *this;
   }
@@ -193,6 +215,12 @@ class MPFRNumber {
     return result;
   }
 
+  // Returns (*this) * b + c computed by mpfr_fma, rounded to nearest
+  // (MPFR_RNDN). The result is copy-constructed from *this, so it carries this
+  // number's precision; b and c are used at whatever precision they hold.
+  MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) {
+    MPFRNumber result(*this);
+    mpfr_fma(result.value, value, b.value, c.value, MPFR_RNDN);
+    return result;
+  }
+
   std::string str() const {
     // 200 bytes should be more than sufficient to hold a 100-digit number
     // plus additional bytes for the decimal point, '-' sign etc.
@@ -328,6 +356,22 @@ binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
   }
 }
 
+// Computes the MPFR reference result of the ternary operation `op` on
+// (x, y, z). Inputs, and therefore the result, use Precision<InputType>::value
+// bits.
+// NOTE(review): MPFR's default exponent range is far wider than InputType's, so
+// a result in InputType's subnormal range is still rounded to the full
+// Precision<InputType> bits here, and may be rounded again when converted to
+// InputType. Emulating InputType exactly there would presumably need a matching
+// emin/emax plus mpfr_subnormalize -- confirm before relying on 0.5-ulp checks
+// for subnormal results.
+template <typename InputType>
+cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
+ternaryOperationOneOutput(Operation op, InputType x, InputType y, InputType z) {
+  // For FMA function, we just need to compare with the mpfr_fma with the same
+  // precision as InputType.  Using higher precision as the intermediate results
+  // to compare might incorrectly fail due to double-rounding errors.
+  constexpr unsigned int prec = Precision<InputType>::value;
+  MPFRNumber inputX(x, prec), inputY(y, prec), inputZ(z, prec);
+  switch (op) {
+  case Operation::Fma:
+    return inputX.fma(inputY, inputZ);
+  default:
+    // Only ternary operations are routed here, and Fma is the only one so far.
+    __builtin_unreachable();
+  }
+}
+
 template <typename T>
 void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
                                             testutils::StreamWrapper &OS) {
@@ -476,6 +520,48 @@ template void explainBinaryOperationOneOutputError<long double>(
     Operation, const BinaryInput<long double> &, long double,
     testutils::StreamWrapper &);
 
+// Writes a diagnostic to OS for a ternary operation whose libc result did not
+// match the MPFR reference: the inputs (decimal and bit patterns), both results
+// and the ULP error of the libc result.
+template <typename T>
+void explainTernaryOperationOneOutputError(Operation op,
+                                           const TernaryInput<T> &input,
+                                           T libcResult,
+                                           testutils::StreamWrapper &OS) {
+  // Inputs are printed at the same precision the reference computation uses.
+  MPFRNumber mpfrX(input.x, Precision<T>::value);
+  MPFRNumber mpfrY(input.y, Precision<T>::value);
+  MPFRNumber mpfrZ(input.z, Precision<T>::value);
+  MPFRNumber mpfrResult =
+      ternaryOperationOneOutput(op, input.x, input.y, input.z);
+  MPFRNumber mpfrMatchValue(libcResult);
+
+  OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str()
+     << " z: " << mpfrZ.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue("First input bits: ", input.x,
+                                              OS);
+  __llvm_libc::fputil::testing::describeValue("Second input bits: ", input.y,
+                                              OS);
+  __llvm_libc::fputil::testing::describeValue("Third input bits: ", input.z,
+                                              OS);
+
+  OS << "Libc result: " << mpfrMatchValue.str() << '\n'
+     << "MPFR result: " << mpfrResult.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue(
+      "Libc floating point result bits: ", libcResult, OS);
+  __llvm_libc::fputil::testing::describeValue(
+      "              MPFR rounded bits: ", mpfrResult.as<T>(), OS);
+  OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult)) << '\n';
+}
+
+template void explainTernaryOperationOneOutputError<float>(
+    Operation, const TernaryInput<float> &, float, testutils::StreamWrapper &);
+template void explainTernaryOperationOneOutputError<double>(
+    Operation, const TernaryInput<double> &, double,
+    testutils::StreamWrapper &);
+template void explainTernaryOperationOneOutputError<long double>(
+    Operation, const TernaryInput<long double> &, long double,
+    testutils::StreamWrapper &);
+
 template <typename T>
 bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
                                        double ulpError) {
@@ -575,6 +661,27 @@ compareBinaryOperationOneOutput<double>(Operation, const BinaryInput<double> &,
 template bool compareBinaryOperationOneOutput<long double>(
     Operation, const BinaryInput<long double> &, long double, double);
 
+// Returns true if libcResult is within ulpError ULPs of the MPFR reference for
+// `op` on `input`. An error exactly equal to ulpError is accepted too, except
+// that an error of exactly 0.5 ULP (a tie) is only accepted when libcResult's
+// lowest encoding bit is 0, i.e. when it agrees with round-half-to-even.
+template <typename T>
+bool compareTernaryOperationOneOutput(Operation op,
+                                      const TernaryInput<T> &input,
+                                      T libcResult, double ulpError) {
+  MPFRNumber mpfrResult =
+      ternaryOperationOneOutput(op, input.x, input.y, input.z);
+  double ulp = mpfrResult.ulp(libcResult);
+
+  bool bitsAreEven = ((FPBits<T>(libcResult).bitsAsUInt() & 1) == 0);
+  return (ulp < ulpError) ||
+         ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
+}
+
+template bool
+compareTernaryOperationOneOutput<float>(Operation, const TernaryInput<float> &,
+                                        float, double);
+template bool compareTernaryOperationOneOutput<double>(
+    Operation, const TernaryInput<double> &, double, double);
+template bool compareTernaryOperationOneOutput<long double>(
+    Operation, const TernaryInput<long double> &, long double, double);
+
 static mpfr_rnd_t getMPFRRoundingMode(RoundingMode mode) {
   switch (mode) {
   case RoundingMode::Upward:

diff  --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index db20b1b2d8c8..17f8a09e80ba 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -57,8 +57,11 @@ enum class Operation : int {
   RemQuo, // The first output, the floating point output, is the remainder.
   EndBinaryOperationsTwoOutputs,
 
+  // Operations which take three floating point numbers of the same type as
+  // input and produce a single floating point number of the same type as
+  // output.
   BeginTernaryOperationsSingleOuput,
-  // TODO: Add operations like fma.
+  Fma,
   EndTernaryOperationsSingleOutput,
 };
 
@@ -113,6 +116,11 @@ template <typename T>
 bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input,
                                      T libcOutput, double t);
 
+// Returns true if libcOutput matches the MPFR result of `op` on `input` within
+// t ULPs (see the definition for the half-ULP tie rule).
+template <typename T>
+bool compareTernaryOperationOneOutput(Operation op,
+                                      const TernaryInput<T> &input,
+                                      T libcOutput, double t);
+
 template <typename T>
 void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
                                             testutils::StreamWrapper &OS);
@@ -132,6 +140,12 @@ void explainBinaryOperationOneOutputError(Operation op,
                                           T matchValue,
                                           testutils::StreamWrapper &OS);
 
+// Writes to OS a description of how matchValue differs from the MPFR result of
+// `op` on `input`.
+template <typename T>
+void explainTernaryOperationOneOutputError(Operation op,
+                                           const TernaryInput<T> &input,
+                                           T matchValue,
+                                           testutils::StreamWrapper &OS);
+
 template <Operation op, typename InputType, typename OutputType>
 class MPFRMatcher : public testing::Matcher<OutputType> {
   InputType input;
@@ -174,7 +188,7 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
 
   template <typename T>
   static bool match(const TernaryInput<T> &in, T out, double tolerance) {
-    // TODO: Implement the comparision function and error reporter.
+    // Delegate to the out-of-line MPFR comparison for ternary operations.
+    return compareTernaryOperationOneOutput(op, in, out, tolerance);
   }
 
   template <typename T>
@@ -199,6 +213,12 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                            testutils::StreamWrapper &OS) {
     explainBinaryOperationOneOutputError(op, in, out, OS);
   }
+
+  // Ternary overload: reports a mismatch via the out-of-line explainer.
+  template <typename T>
+  static void explainError(const TernaryInput<T> &in, T out,
+                           testutils::StreamWrapper &OS) {
+    explainTernaryOperationOneOutputError(op, in, out, OS);
+  }
 };
 
 } // namespace internal


        


More information about the libc-commits mailing list