[libc] [llvm] [libc][complex] Testing infra for MPC (PR #121261)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jan 11 07:10:19 PST 2025


================
@@ -0,0 +1,271 @@
+//===-- MPCUtils.h ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H
+#define LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H
+
+#include "src/__support/CPP/type_traits.h"
+#include "src/__support/CPP/type_traits/is_complex.h"
+#include "src/__support/complex_type.h"
+#include "src/__support/macros/config.h"
+#include "src/__support/macros/properties/complex_types.h"
+#include "src/__support/macros/properties/types.h"
+#include "test/UnitTest/RoundingModeUtils.h"
+#include "test/UnitTest/Test.h"
+
+#include <stdint.h>
+
+namespace LIBC_NAMESPACE_DECL {
+namespace testing {
+namespace mpc {
+
+// Complex operations whose libc results can be checked against MPC.
+//
+// The Begin*/End* enumerators are sentinels that delimit each category.
+// is_valid_operation() tests category membership with strict `<` comparisons
+// on enumerator order, so a new operation must be inserted between the
+// matching Begin*/End* pair, and existing enumerators must not be reordered
+// across sentinels.
+enum class Operation {
+  // Operations which take a single complex floating point number as input
+  // and produce a single floating point number as output which has the same
+  // floating point type as the real/imaginary part of the input.
+  BeginUnaryOperationsSingleOutputDifferentOutputType,
+  Carg,
+  Cabs,
+  EndUnaryOperationsSingleOutputDifferentOutputType,
+
+  // Operations which take a single complex floating point number as input
+  // and produce a single complex floating point number of the same kind
+  // as output.
+  BeginUnaryOperationsSingleOutputSameOutputType,
+  Cproj,
+  Csqrt,
+  Clog,
+  Cexp,
+  Csinh,
+  Ccosh,
+  Ctanh,
+  Casinh,
+  Cacosh,
+  Catanh,
+  Csin,
+  Ccos,
+  Ctan,
+  Casin,
+  Cacos,
+  Catan,
+  EndUnaryOperationsSingleOutputSameOutputType,
+
+  // Operations which take two complex floating point numbers as input
+  // and produce a single complex floating point number of the same kind
+  // as output.
+  BeginBinaryOperationsSingleOutput,
+  Cpow,
+  EndBinaryOperationsSingleOutput,
+};
+
+// Bring the test rounding-mode enum into this namespace so the matcher
+// declarations below can name it unqualified.
+using LIBC_NAMESPACE::fputil::testing::RoundingMode;
+
+// Operand pair for operations that take two complex inputs (e.g. Cpow).
+// Only complex floating point types are accepted.
+template <typename T>
+struct BinaryInput {
+  static_assert(LIBC_NAMESPACE::cpp::is_complex_v<T>,
+                "Template parameter of BinaryInput must be a complex floating "
+                "point type.");
+
+  // Complex type shared by both operands.
+  using Type = T;
+
+  // First operand.
+  T x;
+  // Second operand.
+  T y;
+};
+
+namespace internal {
+
+template <typename InputType, typename OutputType>
+bool compare_unary_operation_single_output_same_type(Operation op,
+                                                     InputType input,
+                                                     OutputType libc_output,
+                                                     double ulp_tolerance,
+                                                     RoundingMode rounding);
+
+template <typename InputType, typename OutputType>
+bool compare_unary_operation_single_output_different_type(
+    Operation op, InputType input, OutputType libc_output, double ulp_tolerance,
+    RoundingMode rounding);
+
+template <typename InputType, typename OutputType>
+bool compare_binary_operation_one_output(Operation op,
+                                         const BinaryInput<InputType> &input,
+                                         OutputType libc_output,
+                                         double ulp_tolerance,
+                                         RoundingMode rounding);
+
+template <typename InputType, typename OutputType>
+void explain_unary_operation_single_output_same_type_error(
+    Operation op, InputType input, OutputType match_value, double ulp_tolerance,
+    RoundingMode rounding);
+
+template <typename InputType, typename OutputType>
+void explain_unary_operation_single_output_different_type_error(
+    Operation op, InputType input, OutputType match_value, double ulp_tolerance,
+    RoundingMode rounding);
+
+template <typename InputType, typename OutputType>
+void explain_binary_operation_one_output_error(
+    Operation op, const BinaryInput<InputType> &input, OutputType match_value,
+    double ulp_tolerance, RoundingMode rounding);
+
+// Test matcher that checks a libc result for operation `op` against MPC.
+//
+// The comparison helper is chosen at compile time from the argument types:
+//  - BinaryInput<T> input            -> binary comparison;
+//  - InputType == OutputType         -> same-type unary comparison;
+//  - otherwise                       -> different-type unary comparison.
+template <Operation op, typename InputType, typename OutputType>
+class MPCMatcher : public testing::Matcher<OutputType> {
+private:
+  InputType input;
+  OutputType match_value;
+  double ulp_tolerance;
+  RoundingMode rounding;
+
+public:
+  // match_value is value-initialized so explainError() never reads an
+  // indeterminate value if it is called before match().
+  MPCMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
+      : input(testInput), match_value(), ulp_tolerance(ulp_tolerance),
+        rounding(rounding) {}
+
+  // Stores libcResult (so explainError() can report it) and returns whether
+  // it matches MPC's result for the stored input.
+  bool match(OutputType libcResult) {
+    match_value = libcResult;
+    return match(input, match_value);
+  }
+
+  // Describes the mismatch for the value recorded by the last match() call.
+  void explainError() override { // NOLINT
+    explain_error(input, match_value);
+  }
+
+private:
+  // Unary dispatch. `if constexpr` discards the branch for the other category,
+  // so only the helper specialization that can actually run is referenced.
+  // A plain `if` would odr-use both helpers, and since they are only declared
+  // in this header, the never-taken one would still need a definition.
+  template <typename InType, typename OutType>
+  bool match(InType in, OutType out) {
+    if constexpr (cpp::is_same_v<InType, OutType>) {
+      return compare_unary_operation_single_output_same_type(
+          op, in, out, ulp_tolerance, rounding);
+    } else {
+      return compare_unary_operation_single_output_different_type(
+          op, in, out, ulp_tolerance, rounding);
+    }
+  }
+
+  // Binary dispatch; preferred over the generic overload for BinaryInput.
+  template <typename T, typename U>
+  bool match(const BinaryInput<T> &in, U out) {
+    return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
+                                               rounding);
+  }
+
+  // Unary error explanation; same compile-time dispatch as match().
+  template <typename InType, typename OutType>
+  void explain_error(InType in, OutType out) {
+    if constexpr (cpp::is_same_v<InType, OutType>) {
+      explain_unary_operation_single_output_same_type_error(
+          op, in, out, ulp_tolerance, rounding);
+    } else {
+      explain_unary_operation_single_output_different_type_error(
+          op, in, out, ulp_tolerance, rounding);
+    }
+  }
+
+  // Binary error explanation.
+  template <typename T, typename U>
+  void explain_error(const BinaryInput<T> &in, U out) {
+    explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
+                                              rounding);
+  }
+};
+
+} // namespace internal
+
+// Return true if the input and output types for the operation op are valid
+// types:
+//  - binary operations take BinaryInput<T> and return T, for a complex T;
+//  - same-type unary operations take and return the same complex type;
+//  - different-type unary operations take a complex type and return its real
+//    type (make_real_t).
+template <Operation op, typename InputType, typename OutputType>
+constexpr bool is_valid_operation() {
+  // Binary operands arrive wrapped in BinaryInput, which is not itself a
+  // complex type, so the complex check is applied to the result type and the
+  // input must be exactly BinaryInput<OutputType>.
+  // The unary checks use cpp::is_same_v so validation agrees exactly with the
+  // same-type / different-type dispatch performed by MPCMatcher.
+  return (Operation::BeginBinaryOperationsSingleOutput < op &&
+          op < Operation::EndBinaryOperationsSingleOutput &&
+          cpp::is_complex_v<OutputType> &&
+          cpp::is_same_v<InputType, BinaryInput<OutputType>>) ||
+         (Operation::BeginUnaryOperationsSingleOutputSameOutputType < op &&
+          op < Operation::EndUnaryOperationsSingleOutputSameOutputType &&
+          cpp::is_same_v<InputType, OutputType> &&
+          cpp::is_complex_v<InputType>) ||
+         (Operation::BeginUnaryOperationsSingleOutputDifferentOutputType < op &&
+          op < Operation::EndUnaryOperationsSingleOutputDifferentOutputType &&
+          cpp::is_same_v<make_real_t<InputType>, OutputType> &&
+          cpp::is_complex_v<InputType>);
+}
+
+// Builds the MPC matcher for `op`. The `output` argument exists only so that
+// OutputType can be deduced from the libc result at the call site. This
+// overload is removed from overload resolution unless is_valid_operation
+// accepts the (op, InputType, OutputType) combination.
+template <Operation op, typename InputType, typename OutputType>
+cpp::enable_if_t<is_valid_operation<op, InputType, OutputType>(),
+                 internal::MPCMatcher<op, InputType, OutputType>>
+get_mpc_matcher(InputType input, [[maybe_unused]] OutputType output,
+                double ulp_tolerance, RoundingMode rounding) {
+  using MatcherType = internal::MPCMatcher<op, InputType, OutputType>;
+  return MatcherType(input, ulp_tolerance, rounding);
+}
+
+} // namespace mpc
+} // namespace testing
+} // namespace LIBC_NAMESPACE_DECL
+
+// Non-fatal check (via EXPECT_THAT) that the libc result `match_value` agrees
+// with MPC's evaluation of `op` on `input`, with tolerance `ulp_tolerance`,
+// in round-to-nearest mode.
+// Note: `match_value` is expanded twice (as the checked value and to deduce
+// the matcher's output type), so an argument with side effects runs twice.
+#define EXPECT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)        \
+  EXPECT_THAT(match_value,                                                     \
+              LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(               \
+                  input, match_value, ulp_tolerance,                           \
+                  LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest))
+
+// Like EXPECT_MPC_MATCH_DEFAULT, but compares under the given `rounding`
+// mode. `match_value` is likewise expanded twice.
+#define EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,       \
+                                  rounding)                                    \
+  EXPECT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(  \
+                               input, match_value, ulp_tolerance, rounding))
+
+// Forces `rounding` and, if that succeeds, runs EXPECT_MPC_MATCH_ROUNDING
+// under it; otherwise the check is skipped. The leading `i` argument is
+// accepted only so existing callers keep compiling — it is unused, because
+// the guard lives in its own block scope and needs no unique name. The guard
+// is fully qualified so this helper does not depend on a caller-side alias.
+#define EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(i, op, input, match_value,        \
+                                             ulp_tolerance, rounding)          \
+  {                                                                            \
+    LIBC_NAMESPACE::fputil::testing::ForceRoundingMode mpc_rnd_guard(          \
+        rounding);                                                             \
+    if (mpc_rnd_guard.success) {                                               \
+      EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
+                                rounding);                                     \
+    }                                                                          \
+  }
+
+// Runs the MPC comparison once per rounding mode, casting the indices 0-3 to
+// RoundingMode (this relies on those being exactly its enumerator values).
+// Loop-local names carry an `mpc_` prefix so they cannot shadow caller
+// variables referenced by `input` or `match_value` (e.g. an `i` from an
+// enclosing test loop).
+#define EXPECT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)   \
+  {                                                                            \
+    namespace MPCRND = LIBC_NAMESPACE::fputil::testing;                        \
+    for (int mpc_rnd_idx = 0; mpc_rnd_idx < 4; ++mpc_rnd_idx) {                \
+      MPCRND::RoundingMode mpc_rnd_mode =                                      \
+          static_cast<MPCRND::RoundingMode>(mpc_rnd_idx);                      \
+      EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(mpc_rnd_idx, op, input,             \
+                                           match_value, ulp_tolerance,         \
+                                           mpc_rnd_mode);                      \
+    }                                                                          \
+  }
+
+// Expression form of the MPC check: evaluates to the bool returned by the
+// matcher's match(), without going through EXPECT_THAT/ASSERT_THAT (so no
+// failure is reported and explainError() is not called).
+// Note: `match_value` is expanded twice, so side effects in it run twice.
+#define TEST_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,         \
+                                rounding)                                      \
+  LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(input, match_value,        \
+                                                    ulp_tolerance, rounding)   \
+      .match(match_value)
+
+// ASSERT_THAT counterpart of EXPECT_MPC_MATCH_DEFAULT (presumably fatal to the
+// current test on mismatch, per the usual ASSERT_* convention): compares
+// `match_value` with MPC's evaluation of `op` on `input` in round-to-nearest.
+// Note: `match_value` is expanded twice, so side effects in it run twice.
+#define ASSERT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)        \
+  ASSERT_THAT(match_value,                                                     \
+              LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(               \
+                  input, match_value, ulp_tolerance,                           \
+                  LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest))
+
+// ASSERT_THAT counterpart of EXPECT_MPC_MATCH_ROUNDING, using the given
+// `rounding` mode.
+#define ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,       \
+                                  rounding)                                    \
+  ASSERT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(  \
+                               input, match_value, ulp_tolerance, rounding))
+
+#define ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(i, op, input, match_value,        \
----------------
lntue wrote:

I don't think the first `i` argument is needed here.

https://github.com/llvm/llvm-project/pull/121261


More information about the llvm-commits mailing list