[llvm-branch-commits] [libc] 4726bec - [libc] Add implementation of fmaf.
Tue Ly via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jan 6 14:19:33 PST 2021
Author: Tue Ly
Date: 2021-01-06T17:14:20-05:00
New Revision: 4726bec8f29bd535e2709b491d223d42bd20c120
URL: https://github.com/llvm/llvm-project/commit/4726bec8f29bd535e2709b491d223d42bd20c120
DIFF: https://github.com/llvm/llvm-project/commit/4726bec8f29bd535e2709b491d223d42bd20c120.diff
LOG: [libc] Add implementation of fmaf.
Differential Revision: https://reviews.llvm.org/D94018
Added:
libc/src/math/fmaf.cpp
libc/src/math/fmaf.h
libc/test/src/math/FmaTest.h
libc/test/src/math/fmaf_test.cpp
Modified:
libc/config/linux/aarch64/entrypoints.txt
libc/config/linux/x86_64/entrypoints.txt
libc/spec/stdc.td
libc/src/math/CMakeLists.txt
libc/test/src/math/CMakeLists.txt
libc/utils/FPUtil/FPBits.h
libc/utils/MPFRWrapper/MPFRUtils.cpp
libc/utils/MPFRWrapper/MPFRUtils.h
Removed:
################################################################################
diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index b9042625e666..0db8c4b39caa 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -65,6 +65,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.floor
libc.src.math.floorf
libc.src.math.floorl
+ libc.src.math.fmaf
libc.src.math.fmax
libc.src.math.fmaxf
libc.src.math.fmaxl
diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index a34c59646149..a80a8b4f105b 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -106,6 +106,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.floor
libc.src.math.floorf
libc.src.math.floorl
+ libc.src.math.fmaf
libc.src.math.fmin
libc.src.math.fminf
libc.src.math.fminl
diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td
index 41f6083a2336..e89d16633ae3 100644
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -322,6 +322,8 @@ def StdC : StandardSpec<"stdc"> {
FunctionSpec<"fmaxf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<FloatType>]>,
FunctionSpec<"fmaxl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>, ArgSpec<LongDoubleType>]>,
+ FunctionSpec<"fmaf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<FloatType>, ArgSpec<FloatType>]>,
+
FunctionSpec<"frexp", RetValSpec<DoubleType>, [ArgSpec<DoubleType>, ArgSpec<IntPtr>]>,
FunctionSpec<"frexpf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<IntPtr>]>,
FunctionSpec<"frexpl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>, ArgSpec<IntPtr>]>,
diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index 15a2e4645edd..34b7dfcd4306 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -978,3 +978,14 @@ add_entrypoint_object(
-O2
)
+add_entrypoint_object(
+ fmaf
+ SRCS
+ fmaf.cpp
+ HDRS
+ fmaf.h
+ DEPENDS
+ libc.utils.FPUtil.fputil
+ COMPILE_OPTIONS
+ -O2
+)
diff --git a/libc/src/math/fmaf.cpp b/libc/src/math/fmaf.cpp
new file mode 100644
index 000000000000..1860d887d630
--- /dev/null
+++ b/libc/src/math/fmaf.cpp
@@ -0,0 +1,64 @@
+//===-- Implementation of fmaf function -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/common.h"
+
+#include "utils/FPUtil/FEnv.h"
+#include "utils/FPUtil/FPBits.h"
+
+namespace __llvm_libc {
+
+float LLVM_LIBC_ENTRYPOINT(fmaf)(float x, float y, float z) {
+ // Compute x * y + z rounded once to float. The product is exact in double.
+ double prod = static_cast<double>(x) * static_cast<double>(y);
+ double z_d = static_cast<double>(z);
+ double sum = prod + z_d;
+ fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum);
+
+ if (!(bit_sum.isInfOrNaN() || bit_sum.isZero())) {
+ // Since the sum is computed in double precision, rounding might happen
+ // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
+ // bit_prod.exponent > bitz.exponent + 40). In that case, when we round
+ // the sum back to float, double rounding error might occur.
+ // A concrete example of this phenomenon is as follows:
+ // x = y = 1 + 2^(-12), z = 2^(-53)
+ // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
+ // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
+ // On the other hand, with the default rounding mode,
+ // double(x*y + z) = 1 + 2^(-11) + 2^(-24)
+ // and casting again to float gives us:
+ // float(double(x*y + z)) = 1 + 2^(-11).
+ //
+ // To correct this possible double rounding error, we first use Dekker's
+ // 2Sum algorithm to find t such that sum - t = prod + z exactly, assuming
+ // the (default) rounding mode is round-to-nearest, ties-to-even. Moreover,
+ // t satisfies |t| < eps(sum), i.e., t.exponent < sum.exponent - 52. So if
+ // t is not 0, meaning the double addition was inexact, we only need to
+ // nudge the last bit of sum in the direction of the exact value, so that
+ // the sticky bits seen when rounding sum to float are correct (when it
+ // matters).
+ fputil::FPBits<double> t(
+ (bit_prod.exponent >= bitz.exponent)
+ ? ((static_cast<double>(bit_sum) - bit_prod) - bitz)
+ : ((static_cast<double>(bit_sum) - bitz) - bit_prod));
+
+ // Adjust only if t != 0.0 and the low (52 - 23 - 1 = 28) mantissa bits of
+ // sum are all zero, i.e. the sticky bits would otherwise be lost.
+ if (!t.isZero() && ((bit_sum.mantissa & 0xfff'ffffULL) == 0)) {
+ if (bit_sum.sign != t.sign) { // |exact| > |sum|: bump away from zero.
+ ++bit_sum.mantissa;
+ } else if (bit_sum.mantissa) { // |exact| < |sum|: pull toward zero.
+ --bit_sum.mantissa;
+ }
+ }
+ }
+
+ return static_cast<float>(static_cast<double>(bit_sum));
+}
+
+} // namespace __llvm_libc
diff --git a/libc/src/math/fmaf.h b/libc/src/math/fmaf.h
new file mode 100644
index 000000000000..48fbb65d6650
--- /dev/null
+++ b/libc/src/math/fmaf.h
@@ -0,0 +1,18 @@
+//===-- Implementation header for fmaf --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_FMAF_H
+#define LLVM_LIBC_SRC_MATH_FMAF_H
+
+namespace __llvm_libc {
+
+float fmaf(float x, float y, float z);
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_MATH_FMAF_H
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index 3f4d8c76d7c0..48f8902e3f59 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -1049,3 +1049,16 @@ add_fp_unittest(
libc.src.math.nextafterl
libc.utils.FPUtil.fputil
)
+
+add_fp_unittest(
+ fmaf_test
+ NEED_MPFR
+ SUITE
+ libc_math_unittests
+ SRCS
+ fmaf_test.cpp
+ DEPENDS
+ libc.include.math
+ libc.src.math.fmaf
+ libc.utils.FPUtil.fputil
+)
diff --git a/libc/test/src/math/FmaTest.h b/libc/test/src/math/FmaTest.h
new file mode 100644
index 000000000000..c39c4ad0f1da
--- /dev/null
+++ b/libc/test/src/math/FmaTest.h
@@ -0,0 +1,94 @@
+//===-- Utility class to test different flavors of fma --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
+#define LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
+
+#include "utils/FPUtil/FPBits.h"
+#include "utils/FPUtil/TestHelpers.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+#include "utils/UnitTest/Test.h"
+
+#include <random>
+
+namespace mpfr = __llvm_libc::testing::mpfr;
+
+template <typename T>
+class FmaTestTemplate : public __llvm_libc::testing::Test {
+private:
+ using Func = T (*)(T, T, T);
+ using FPBits = __llvm_libc::fputil::FPBits<T>;
+ using UIntType = typename FPBits::UIntType;
+ const T nan = __llvm_libc::fputil::FPBits<T>::buildNaN(1);
+ const T inf = __llvm_libc::fputil::FPBits<T>::inf();
+ const T negInf = __llvm_libc::fputil::FPBits<T>::negInf();
+ const T zero = __llvm_libc::fputil::FPBits<T>::zero();
+ const T negZero = __llvm_libc::fputil::FPBits<T>::negZero();
+
+ UIntType getRandomBitPattern() {
+ UIntType bits{0};
+ for (size_t i = 0; i < sizeof(UIntType) / 2; ++i) {
+ bits = (bits << 2) + static_cast<uint16_t>(std::rand()); // NOTE(review): shifting by 2 (not 16) makes the 16-bit chunks overlap; confirm intended.
+ }
+ return bits;
+ }
+
+public:
+ void testSpecialNumbers(Func func) {
+ EXPECT_FP_EQ(func(zero, zero, zero), zero);
+ EXPECT_FP_EQ(func(zero, negZero, negZero), negZero);
+ EXPECT_FP_EQ(func(inf, inf, zero), inf);
+ EXPECT_FP_EQ(func(negInf, inf, negInf), negInf);
+ EXPECT_FP_EQ(func(inf, zero, zero), nan);
+ EXPECT_FP_EQ(func(inf, negInf, inf), nan);
+ EXPECT_FP_EQ(func(nan, zero, inf), nan);
+ EXPECT_FP_EQ(func(inf, negInf, nan), nan);
+
+ // Test underflow rounding up: the expected result has bit pattern 2.
+ EXPECT_FP_EQ(func(T(0.5), FPBits(FPBits::minSubnormal),
+ FPBits(FPBits::minSubnormal)),
+ FPBits(UIntType(2)));
+ // Test underflow rounding down.
+ FPBits v(FPBits::minNormal + UIntType(1));
+ EXPECT_FP_EQ(
+ func(T(1) / T(FPBits::minNormal << 1), v, FPBits(FPBits::minNormal)),
+ v);
+ // Test overflow: 1.75 * z is out of range, yet 1.75 * z - z = 0.75 * z fits.
+ FPBits z(FPBits::maxNormal);
+ EXPECT_FP_EQ(func(T(1.75), z, -z), T(0.75) * z);
+ }
+
+ void testSubnormalRange(Func func) {
+ constexpr UIntType count = 1000001;
+ constexpr UIntType step =
+ (FPBits::maxSubnormal - FPBits::minSubnormal) / count;
+ for (UIntType v = FPBits::minSubnormal, w = FPBits::maxSubnormal;
+ v <= FPBits::maxSubnormal && w >= FPBits::minSubnormal;
+ v += step, w -= step) {
+ T x = FPBits(getRandomBitPattern()), y = FPBits(v), z = FPBits(w);
+ T result = func(x, y, z);
+ mpfr::TernaryInput<T> input{x, y, z};
+ ASSERT_MPFR_MATCH(mpfr::Operation::Fma, input, result, 0.5); // Tolerance: 0.5 ULP.
+ }
+ }
+
+ void testNormalRange(Func func) {
+ constexpr UIntType count = 1000001;
+ constexpr UIntType step = (FPBits::maxNormal - FPBits::minNormal) / count;
+ for (UIntType v = FPBits::minNormal, w = FPBits::maxNormal;
+ v <= FPBits::maxNormal && w >= FPBits::minNormal;
+ v += step, w -= step) {
+ T x = FPBits(v), y = FPBits(w), z = FPBits(getRandomBitPattern());
+ T result = func(x, y, z);
+ mpfr::TernaryInput<T> input{x, y, z};
+ ASSERT_MPFR_MATCH(mpfr::Operation::Fma, input, result, 0.5); // Tolerance: 0.5 ULP.
+ }
+ }
+};
+
+#endif // LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
diff --git a/libc/test/src/math/fmaf_test.cpp b/libc/test/src/math/fmaf_test.cpp
new file mode 100644
index 000000000000..96d3ca7cf113
--- /dev/null
+++ b/libc/test/src/math/fmaf_test.cpp
@@ -0,0 +1,19 @@
+//===-- Unittests for fmaf ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "FmaTest.h"
+
+#include "src/math/fmaf.h"
+
+using FmaTest = FmaTestTemplate<float>;
+
+TEST_F(FmaTest, SpecialNumbers) { testSpecialNumbers(&__llvm_libc::fmaf); }
+
+TEST_F(FmaTest, SubnormalRange) { testSubnormalRange(&__llvm_libc::fmaf); }
+
+TEST_F(FmaTest, NormalRange) { testNormalRange(&__llvm_libc::fmaf); }
diff --git a/libc/utils/FPUtil/FPBits.h b/libc/utils/FPUtil/FPBits.h
index 89bdd92669b8..02c0dc880579 100644
--- a/libc/utils/FPUtil/FPBits.h
+++ b/libc/utils/FPUtil/FPBits.h
@@ -84,7 +84,10 @@ template <typename T> struct __attribute__((packed)) FPBits {
// We don't want accidental type promotions/conversions so we require exact
// type match.
template <typename XType,
- cpp::EnableIfType<cpp::IsSame<T, XType>::Value, int> = 0>
+ cpp::EnableIfType<cpp::IsSame<T, XType>::Value ||
+ (cpp::IsIntegral<XType>::Value &&
+ (sizeof(XType) == sizeof(UIntType))),
+ int> = 0>
explicit FPBits(XType x) {
*this = *reinterpret_cast<FPBits<T> *>(&x);
}
@@ -106,13 +109,6 @@ template <typename T> struct __attribute__((packed)) FPBits {
// the potential software implementations of UIntType will not slow real
// code.
- template <typename XType,
- cpp::EnableIfType<cpp::IsSame<UIntType, XType>::Value, int> = 0>
- explicit FPBits<long double>(XType x) {
- // The last 4 bytes of v are ignored in case of i386.
- *this = *reinterpret_cast<FPBits<T> *>(&x);
- }
-
UIntType bitsAsUInt() const {
return *reinterpret_cast<const UIntType *>(this);
}
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 3c66ccfe5b38..7cc80a281c6e 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -35,48 +35,69 @@ namespace __llvm_libc {
namespace testing {
namespace mpfr {
+// Number of significand bits of each floating point type. It is used as the
+// MPFR working precision when a reference result must be rounded exactly once
+// to that type.
+template <typename T> struct Precision;
+
+template <> struct Precision<float> {
+  static constexpr unsigned int value = 24;
+};
+
+template <> struct Precision<double> {
+  static constexpr unsigned int value = 53;
+};
+
+// On x86 targets long double is the 80-bit extended format, which has a
+// 64-bit significand; other targets are assumed to use the 128-bit quad
+// format with a 113-bit significand.
+#if defined(__x86_64__) || defined(__i386__)
+template <> struct Precision<long double> {
+  static constexpr unsigned int value = 64;
+};
+#else
+template <> struct Precision<long double> {
+  static constexpr unsigned int value = 113;
+};
+#endif
+
class MPFRNumber {
// A precision value which allows sufficiently large additional
// precision even compared to quad-precision floating point values.
- static constexpr unsigned int mpfrPrecision = 128;
+ unsigned int mpfrPrecision;
mpfr_t value;
public:
- MPFRNumber() { mpfr_init2(value, mpfrPrecision); }
+ MPFRNumber() : mpfrPrecision(128) { mpfr_init2(value, mpfrPrecision); }
// We use explicit EnableIf specializations to disallow implicit
// conversions. Implicit conversions can potentially lead to loss of
// precision.
template <typename XType,
cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
- explicit MPFRNumber(XType x) {
+ explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_flt(value, x, MPFR_RNDN);
}
template <typename XType,
cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
- explicit MPFRNumber(XType x) {
+ explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_d(value, x, MPFR_RNDN);
}
template <typename XType,
cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
- explicit MPFRNumber(XType x) {
+ explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_ld(value, x, MPFR_RNDN);
}
template <typename XType,
cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
- explicit MPFRNumber(XType x) {
+ explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_sj(value, x, MPFR_RNDN);
}
- MPFRNumber(const MPFRNumber &other) {
+ MPFRNumber(const MPFRNumber &other) : mpfrPrecision(other.mpfrPrecision) {
+ mpfr_init2(value, mpfrPrecision);
mpfr_set(value, other.value, MPFR_RNDN);
}
@@ -85,6 +106,7 @@ class MPFRNumber {
}
MPFRNumber &operator=(const MPFRNumber &rhs) {
+ mpfrPrecision = rhs.mpfrPrecision;
mpfr_set(value, rhs.value, MPFR_RNDN);
return *this;
}
@@ -193,6 +215,12 @@ class MPFRNumber {
return result;
}
+ MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) {
+ MPFRNumber result(*this);
+ mpfr_fma(result.value, value, b.value, c.value, MPFR_RNDN);
+ return result;
+ }
+
std::string str() const {
// 200 bytes should be more than sufficient to hold a 100-digit number
// plus additional bytes for the decimal point, '-' sign etc.
@@ -328,6 +356,22 @@ binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
}
}
+template <typename InputType>
+cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
+ternaryOperationOneOutput(Operation op, InputType x, InputType y, InputType z) {
+ // For FMA, compute mpfr_fma directly at the precision of InputType so the
+ // reference is rounded exactly once. Computing at a higher precision and
+ // rounding afterwards could spuriously fail due to double rounding.
+ constexpr unsigned int prec = Precision<InputType>::value;
+ MPFRNumber inputX(x, prec), inputY(y, prec), inputZ(z, prec);
+ switch (op) {
+ case Operation::Fma:
+ return inputX.fma(inputY, inputZ);
+ default:
+ __builtin_unreachable();
+ }
+}
+
template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS) {
@@ -476,6 +520,48 @@ template void explainBinaryOperationOneOutputError<long double>(
Operation, const BinaryInput<long double> &, long double,
testutils::StreamWrapper &);
+// Writes to OS a human-readable report of a mismatch between libcResult and
+// the MPFR result of `op` applied to `input`: the three inputs in decimal and
+// as bits, both results in decimal and as bits, and the ULP error.
+template <typename T>
+void explainTernaryOperationOneOutputError(Operation op,
+                                           const TernaryInput<T> &input,
+                                           T libcResult,
+                                           testutils::StreamWrapper &OS) {
+  // The inputs are held at the precision of T, matching the precision used
+  // to compute the MPFR reference result.
+  MPFRNumber mpfrX(input.x, Precision<T>::value);
+  MPFRNumber mpfrY(input.y, Precision<T>::value);
+  MPFRNumber mpfrZ(input.z, Precision<T>::value);
+  MPFRNumber mpfrResult =
+      ternaryOperationOneOutput(op, input.x, input.y, input.z);
+  MPFRNumber mpfrMatchValue(libcResult);
+
+  OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str()
+     << " z: " << mpfrZ.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue("First input bits: ", input.x,
+                                              OS);
+  __llvm_libc::fputil::testing::describeValue("Second input bits: ", input.y,
+                                              OS);
+  __llvm_libc::fputil::testing::describeValue("Third input bits: ", input.z,
+                                              OS);
+
+  OS << "Libc result: " << mpfrMatchValue.str() << '\n'
+     << "MPFR result: " << mpfrResult.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue(
+      "Libc floating point result bits: ", libcResult, OS);
+  __llvm_libc::fputil::testing::describeValue(
+      " MPFR rounded bits: ", mpfrResult.as<T>(), OS);
+  OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult)) << '\n';
+}
+
+template void explainTernaryOperationOneOutputError<float>(
+ Operation, const TernaryInput<float> &, float, testutils::StreamWrapper &);
+template void explainTernaryOperationOneOutputError<double>(
+ Operation, const TernaryInput<double> &, double,
+ testutils::StreamWrapper &);
+template void explainTernaryOperationOneOutputError<long double>(
+ Operation, const TernaryInput<long double> &, long double,
+ testutils::StreamWrapper &);
+
template <typename T>
bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
double ulpError) {
@@ -575,6 +661,27 @@ compareBinaryOperationOneOutput<double>(Operation, const BinaryInput<double> &,
template bool compareBinaryOperationOneOutput<long double>(
Operation, const BinaryInput<long double> &, long double, double);
+// Returns true when libcResult is acceptably close to the MPFR result of `op`
+// applied to `input`: either strictly less than ulpError ULPs away, or exactly
+// ulpError ULPs away. In the latter case, an error of exactly 0.5 ULP is only
+// accepted when libcResult has an even bit pattern.
+template <typename T>
+bool compareTernaryOperationOneOutput(Operation op,
+                                      const TernaryInput<T> &input,
+                                      T libcResult, double ulpError) {
+  MPFRNumber expected =
+      ternaryOperationOneOutput(op, input.x, input.y, input.z);
+  double error = expected.ulp(libcResult);
+
+  if (error < ulpError)
+    return true;
+  if (error != ulpError)
+    return false;
+
+  // The error sits exactly on the tolerance boundary.
+  if (error != 0.5)
+    return true;
+  // A half-ULP error is a tie; accept it only for an even bit pattern.
+  return (FPBits<T>(libcResult).bitsAsUInt() & 1) == 0;
+}
+
+template bool
+compareTernaryOperationOneOutput<float>(Operation, const TernaryInput<float> &,
+ float, double);
+template bool compareTernaryOperationOneOutput<double>(
+ Operation, const TernaryInput<double> &, double, double);
+template bool compareTernaryOperationOneOutput<long double>(
+ Operation, const TernaryInput<long double> &, long double, double);
+
static mpfr_rnd_t getMPFRRoundingMode(RoundingMode mode) {
switch (mode) {
case RoundingMode::Upward:
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index db20b1b2d8c8..17f8a09e80ba 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -57,8 +57,11 @@ enum class Operation : int {
RemQuo, // The first output, the floating point output, is the remainder.
EndBinaryOperationsTwoOutputs,
+ // Operations which take three floating point numbers of the same type as
+ // input and produce a single floating point number of the same type as
+ // output.
BeginTernaryOperationsSingleOuput,
- // TODO: Add operations like fma.
+ Fma,
EndTernaryOperationsSingleOutput,
};
@@ -113,6 +116,11 @@ template <typename T>
bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input,
T libcOutput, double t);
+template <typename T>
+bool compareTernaryOperationOneOutput(Operation op,
+ const TernaryInput<T> &input,
+ T libcOutput, double t);
+
template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS);
@@ -132,6 +140,12 @@ void explainBinaryOperationOneOutputError(Operation op,
T matchValue,
testutils::StreamWrapper &OS);
+template <typename T>
+void explainTernaryOperationOneOutputError(Operation op,
+ const TernaryInput<T> &input,
+ T matchValue,
+ testutils::StreamWrapper &OS);
+
template <Operation op, typename InputType, typename OutputType>
class MPFRMatcher : public testing::Matcher<OutputType> {
InputType input;
@@ -174,7 +188,7 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
template <typename T>
static bool match(const TernaryInput<T> &in, T out, double tolerance) {
- // TODO: Implement the comparision function and error reporter.
+ return compareTernaryOperationOneOutput(op, in, out, tolerance);
}
template <typename T>
@@ -199,6 +213,12 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
testutils::StreamWrapper &OS) {
explainBinaryOperationOneOutputError(op, in, out, OS);
}
+
+ template <typename T>
+ static void explainError(const TernaryInput<T> &in, T out,
+ testutils::StreamWrapper &OS) {
+ explainTernaryOperationOneOutputError(op, in, out, OS);
+ }
};
} // namespace internal
More information about the llvm-branch-commits
mailing list