[libc] [llvm] [libc][complex] Testing infra for MPC (PR #121261)

Nick Desaulniers via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 14 09:53:35 PST 2025


================
@@ -0,0 +1,356 @@
+//===-- Utils which wrap MPC ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "MPCUtils.h"
+
+#include "src/__support/CPP/array.h"
+#include "src/__support/CPP/stringstream.h"
+#include "utils/MPFRWrapper/MPCommon.h"
+
+#include <stdint.h>
+
+#include "mpc.h"
+
+// Local shorthand for LIBC_NAMESPACE::fputil::FPBits<T>.
+template <class T>
+using FPBits = LIBC_NAMESPACE::fputil::FPBits<T>;
+
+namespace LIBC_NAMESPACE_DECL {
+namespace testing {
+namespace mpc {
+
+// Returns the name of the MPFR rounding-mode constant that corresponds to
+// `mode` (e.g. RoundingMode::Nearest -> "MPFR_RNDN").
+static inline cpp::string str(RoundingMode mode) {
+  switch (mode) {
+  case RoundingMode::Upward:
+    return "MPFR_RNDU";
+  case RoundingMode::Downward:
+    return "MPFR_RNDD";
+  case RoundingMode::TowardZero:
+    return "MPFR_RNDZ";
+  case RoundingMode::Nearest:
+    return "MPFR_RNDN";
+  }
+  // Every enumerator is handled above. Without this, control could flow off
+  // the end of a non-void function (UB for an out-of-range value), and GCC
+  // warns with -Wreturn-type.
+  __builtin_unreachable();
+}
+
+// Owning wrapper around an MPC complex number (`mpc_t`). Every constructor
+// initializes `value` with mpc_init2 and the destructor releases it with
+// mpc_clear.
+class MPCNumber {
+private:
+  // Precision, in bits, used for `value` and for any results derived from it.
+  unsigned int precision;
+  // The wrapped MPC value; owned by this object.
+  mpc_t value;
+  // Rounding mode (real and imaginary parts) used for MPC operations.
+  mpc_rnd_t mpc_rounding;
+
+public:
+  // Creates a number with `p` bits of precision and round-to-nearest rounding
+  // for both parts.
+  MPCNumber(unsigned int p) : precision(p), mpc_rounding(MPC_RNDNN) {
+    mpc_init2(value, precision);
+  }
+
+  // Creates a number with 256 bits of precision and round-to-nearest rounding
+  // for both parts.
+  MPCNumber() : precision(256), mpc_rounding(MPC_RNDNN) {
+    mpc_init2(value, precision);
+  }
+
+  // Converts a `_Complex float` into an MPC number at `precision` bits,
+  // rounding both parts with `rnd`.
+  template <typename XType,
+            cpp::enable_if_t<cpp::is_same_v<_Complex float, XType>, bool> = 0>
+  MPCNumber(XType x,
+            unsigned int precision = mpfr::ExtraPrecision<float>::VALUE,
+            RoundingMode rnd = RoundingMode::Nearest)
+      : precision(precision),
+        mpc_rounding(MPC_RND(mpfr::get_mpfr_rounding_mode(rnd),
+                             mpfr::get_mpfr_rounding_mode(rnd))) {
+    mpc_init2(value, precision);
+    Complex<float> x_c = cpp::bit_cast<Complex<float>>(x);
+    // Load each float component into a temporary MPFR value first, then
+    // combine them into the MPC value.
+    mpfr_t real, imag;
+    mpfr_init2(real, precision);
+    mpfr_init2(imag, precision);
+    mpfr_set_flt(real, x_c.real, mpfr::get_mpfr_rounding_mode(rnd));
+    mpfr_set_flt(imag, x_c.imag, mpfr::get_mpfr_rounding_mode(rnd));
+    mpc_set_fr_fr(value, real, imag, mpc_rounding);
+    mpfr_clear(real);
+    mpfr_clear(imag);
+  }
+
+  // Converts a `_Complex double` into an MPC number at `precision` bits,
+  // rounding both parts with `rnd`.
+  template <typename XType,
+            cpp::enable_if_t<cpp::is_same_v<_Complex double, XType>, bool> = 0>
+  MPCNumber(XType x,
+            unsigned int precision = mpfr::ExtraPrecision<double>::VALUE,
+            RoundingMode rnd = RoundingMode::Nearest)
+      : precision(precision),
+        mpc_rounding(MPC_RND(mpfr::get_mpfr_rounding_mode(rnd),
+                             mpfr::get_mpfr_rounding_mode(rnd))) {
+    mpc_init2(value, precision);
+    Complex<double> x_c = cpp::bit_cast<Complex<double>>(x);
+    mpc_set_d_d(value, x_c.real, x_c.imag, mpc_rounding);
+  }
+
+  // Deep copy: allocates a fresh `mpc_t` and copies the value into it.
+  MPCNumber(const MPCNumber &other)
+      : precision(other.precision), mpc_rounding(other.mpc_rounding) {
+    mpc_init2(value, precision);
+    mpc_set(value, other.value, mpc_rounding);
+  }
+
+  // Deep-copy assignment.
+  MPCNumber &operator=(const MPCNumber &rhs) {
+    // Self-assignment must be a no-op: re-initializing `value` below would
+    // otherwise wipe out the data we are about to copy from.
+    if (this == &rhs)
+      return *this;
+    precision = rhs.precision;
+    mpc_rounding = rhs.mpc_rounding;
+    // `value` is already initialized. Release it before re-initializing at
+    // the (possibly different) precision of `rhs`, since calling mpc_init2
+    // on a live mpc_t leaks its storage.
+    mpc_clear(value);
+    mpc_init2(value, precision);
+    mpc_set(value, rhs.value, mpc_rounding);
+    return *this;
+  }
+
+  // Wraps a copy of the raw MPC value `x`, stored at `p` bits with rounding
+  // mode `rnd`. The caller keeps ownership of `x`.
+  MPCNumber(const mpc_t x, unsigned int p, mpc_rnd_t rnd)
+      : precision(p), mpc_rounding(rnd) {
+    mpc_init2(value, precision);
+    mpc_set(value, x, mpc_rounding);
+  }
+
+  ~MPCNumber() { mpc_clear(value); }
+
+  // Copies the wrapped value into `val`, which must already be initialized by
+  // the caller.
+  void getValue(mpc_t val) const { mpc_set(val, value, mpc_rounding); }
+
+  // Returns arg(value) computed with mpc_arg. The result is stored as the
+  // real part of an MPCNumber; the imaginary part is whatever mpc_set_fr
+  // assigns (zero).
+  MPCNumber carg() const {
+    mpfr_t res;
+    mpc_t res_mpc;
+
+    mpfr_init2(res, precision);
+    mpc_init2(res_mpc, precision);
+
+    mpc_arg(res, value, MPC_RND_RE(mpc_rounding));
+    mpc_set_fr(res_mpc, res, mpc_rounding);
+
+    MPCNumber result(res_mpc, precision, mpc_rounding);
+
+    mpfr_clear(res);
+    mpc_clear(res_mpc);
+
+    return result;
+  }
+
+  // Returns the projection of `value` computed with mpc_proj.
+  MPCNumber cproj() const {
+    mpc_t res;
+
+    mpc_init2(res, precision);
+
+    mpc_proj(res, value, mpc_rounding);
+
+    MPCNumber result(res, precision, mpc_rounding);
+
+    mpc_clear(res);
+
+    return result;
+  }
+};
+
+namespace internal {
+
+// Computes the MPC reference result of the unary operation `op` applied to
+// the complex `input`, evaluated at `precision` bits with `rounding`.
+// Only Operation::Carg and Operation::Cproj are handled; any other `op` is
+// treated as unreachable.
+template <typename InputType>
+cpp::enable_if_t<cpp::is_complex_v<InputType>, MPCNumber>
+unary_operation(Operation op, InputType input, unsigned int precision,
+                RoundingMode rounding) {
+  MPCNumber mpc_input(input, precision, rounding);
+  if (op == Operation::Carg)
+    return mpc_input.carg();
+  if (op == Operation::Cproj)
+    return mpc_input.cproj();
+  __builtin_unreachable();
+}
+
+template <typename InputType, typename OutputType>
+bool compare_unary_operation_single_output_same_type(Operation op,
+                                                     InputType input,
+                                                     OutputType libc_result,
+                                                     double ulp_tolerance,
+                                                     RoundingMode rounding) {
+
+  unsigned int precision =
+      mpfr::get_precision<make_real_t<InputType>>(ulp_tolerance);
+
+  MPCNumber mpc_result;
+  mpc_result = unary_operation(op, input, precision, rounding);
+
+  mpc_t mpc_result_val;
----------------
nickdesaulniers wrote:

Use `MPCNumber` here?

https://github.com/llvm/llvm-project/pull/121261


More information about the llvm-commits mailing list