[flang-commits] [flang] a8d4335 - [flang][AIX] Handle more trig functions with complex argument to have consistent results in folding (#124203)

via flang-commits flang-commits at lists.llvm.org
Wed Jan 29 17:26:43 PST 2025


Author: Kelvin Li
Date: 2025-01-29T20:26:39-05:00
New Revision: a8d4335ee08b18068eb913528b963590f510d0b4

URL: https://github.com/llvm/llvm-project/commit/a8d4335ee08b18068eb913528b963590f510d0b4
DIFF: https://github.com/llvm/llvm-project/commit/a8d4335ee08b18068eb913528b963590f510d0b4.diff

LOG: [flang][AIX] Handle more trig functions with complex argument to have consistent results in folding (#124203)

This patch extends 71d4f34 to all trig functions that take complex
arguments. On AIX, the `libm` routines are called during compile-time
folding instead of the STL routines.

Added: 
    

Modified: 
    flang/lib/Evaluate/intrinsics-library.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Evaluate/intrinsics-library.cpp b/flang/lib/Evaluate/intrinsics-library.cpp
index c1b270f518c0e0..d2c1be65dca448 100644
--- a/flang/lib/Evaluate/intrinsics-library.cpp
+++ b/flang/lib/Evaluate/intrinsics-library.cpp
@@ -260,6 +260,16 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
   static_assert(map.Verify(), "map must be sorted");
 };
 
+#define COMPLEX_SIGNATURES(HOST_T) \
+  using F = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &>; \
+  using F2 = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &, \
+      const std::complex<HOST_T> &>; \
+  using F2A = FuncPointer<std::complex<HOST_T>, const HOST_T &, \
+      const std::complex<HOST_T> &>; \
+  using F2B = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &, \
+      const HOST_T &>;
+
+#ifndef _AIX
 // Helpers to map complex std::pow whose resolution in F2{std::pow} is
 // ambiguous as of clang++ 20.
 template <typename HostT>
@@ -267,98 +277,24 @@ static std::complex<HostT> StdPowF2(
     const std::complex<HostT> &x, const std::complex<HostT> &y) {
   return std::pow(x, y);
 }
+
 template <typename HostT>
 static std::complex<HostT> StdPowF2A(
     const HostT &x, const std::complex<HostT> &y) {
   return std::pow(x, y);
 }
+
 template <typename HostT>
 static std::complex<HostT> StdPowF2B(
     const std::complex<HostT> &x, const HostT &y) {
   return std::pow(x, y);
 }
 
-#ifdef _AIX
-#ifdef __clang_major__
-#pragma clang diagnostic ignored "-Wc99-extensions"
-#endif
-
-extern "C" {
-float _Complex cacosf(float _Complex);
-double _Complex cacos(double _Complex);
-float _Complex csqrtf(float _Complex);
-double _Complex csqrt(double _Complex);
-}
-
-enum CRI { Real, Imag };
-template <typename TR, typename TA> static TR &reIm(TA &x, CRI n) {
-  return reinterpret_cast<TR(&)[2]>(x)[n];
-}
-template <typename TR, typename T> static TR CppToC(const std::complex<T> &x) {
-  TR r;
-  reIm<T, TR>(r, CRI::Real) = x.real();
-  reIm<T, TR>(r, CRI::Imag) = x.imag();
-  return r;
-}
-template <typename T, typename TA> static std::complex<T> CToCpp(const TA &x) {
-  TA &z{const_cast<TA &>(x)};
-  return std::complex<T>(reIm<T, TA>(z, CRI::Real), reIm<T, TA>(z, CRI::Imag));
-}
-#endif
-
-template <typename HostT>
-static std::complex<HostT> CSqrt(const std::complex<HostT> &x) {
-  std::complex<HostT> res;
-#ifdef _AIX
-  // On AIX, the implementation of csqrt[f] and std::sqrt is different,
-  // use csqrt[f] in folding.
-  if constexpr (std::is_same_v<HostT, float>) {
-    float _Complex r{csqrtf(CppToC<float _Complex, float>(x))};
-    res = CToCpp<float, float _Complex>(r);
-  } else if constexpr (std::is_same_v<HostT, double>) {
-    double _Complex r{csqrt(CppToC<double _Complex, double>(x))};
-    res = CToCpp<double, double _Complex>(r);
-  } else {
-    DIE("bad complex component type");
-  }
-#else
-  res = std::sqrt(x);
-#endif
-  return res;
-}
-
-template <typename HostT>
-static std::complex<HostT> CAcos(const std::complex<HostT> &x) {
-  std::complex<HostT> res;
-#ifdef _AIX
-  // On AIX, the implementation of cacos[f] and std::acos is different,
-  // use cacos[f] in folding.
-  if constexpr (std::is_same_v<HostT, float>) {
-    float _Complex r{cacosf(CppToC<float _Complex, float>(x))};
-    res = CToCpp<float, float _Complex>(r);
-  } else if constexpr (std::is_same_v<HostT, double>) {
-    double _Complex r{cacos(CppToC<double _Complex, double>(x))};
-    res = CToCpp<double, double _Complex>(r);
-  } else {
-    DIE("bad complex component type");
-  }
-#else
-  res = std::acos(x);
-#endif
-  return res;
-}
-
 template <typename HostT>
 struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
-  using F = FuncPointer<std::complex<HostT>, const std::complex<HostT> &>;
-  using F2 = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
-      const std::complex<HostT> &>;
-  using F2A = FuncPointer<std::complex<HostT>, const HostT &,
-      const std::complex<HostT> &>;
-  using F2B = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
-      const HostT &>;
+  COMPLEX_SIGNATURES(HostT)
   static constexpr HostRuntimeFunction table[]{
-      FolderFactory<F, F{CAcos}>::Create("acos"),
+      FolderFactory<F, F{std::acos}>::Create("acos"),
       FolderFactory<F, F{std::acosh}>::Create("acosh"),
       FolderFactory<F, F{std::asin}>::Create("asin"),
       FolderFactory<F, F{std::asinh}>::Create("asinh"),
@@ -373,13 +309,129 @@ struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
       FolderFactory<F2B, F2B{StdPowF2B}>::Create("pow"),
       FolderFactory<F, F{std::sin}>::Create("sin"),
       FolderFactory<F, F{std::sinh}>::Create("sinh"),
-      FolderFactory<F, F{CSqrt}>::Create("sqrt"),
+      FolderFactory<F, F{std::sqrt}>::Create("sqrt"),
       FolderFactory<F, F{std::tan}>::Create("tan"),
       FolderFactory<F, F{std::tanh}>::Create("tanh"),
   };
   static constexpr HostRuntimeMap map{table};
   static_assert(map.Verify(), "map must be sorted");
 };
+#else
+// On AIX, call libm routines to preserve consistent value between
+// runtime and compile time evaluation.
+#ifdef __clang_major__
+#pragma clang diagnostic ignored "-Wc99-extensions"
+#endif
+
+extern "C" {
+float _Complex cacosf(float _Complex);
+double _Complex cacos(double _Complex);
+float _Complex cacoshf(float _Complex);
+double _Complex cacosh(double _Complex);
+float _Complex casinf(float _Complex);
+double _Complex casin(double _Complex);
+float _Complex casinhf(float _Complex);
+double _Complex casinh(double _Complex);
+float _Complex catanf(float _Complex);
+double _Complex catan(double _Complex);
+float _Complex catanhf(float _Complex);
+double _Complex catanh(double _Complex);
+float _Complex ccosf(float _Complex);
+double _Complex ccos(double _Complex);
+float _Complex ccoshf(float _Complex);
+double _Complex ccosh(double _Complex);
+float _Complex cexpf(float _Complex);
+double _Complex cexp(double _Complex);
+float _Complex clogf(float _Complex);
+double _Complex __clog(double _Complex);
+float _Complex cpowf(float _Complex, float _Complex);
+double _Complex cpow(double _Complex, double _Complex);
+float _Complex csinf(float _Complex);
+double _Complex csin(double _Complex);
+float _Complex csinhf(float _Complex);
+double _Complex csinh(double _Complex);
+float _Complex csqrtf(float _Complex);
+double _Complex csqrt(double _Complex);
+float _Complex ctanf(float _Complex);
+double _Complex ctan(double _Complex);
+float _Complex ctanhf(float _Complex);
+double _Complex ctanh(double _Complex);
+}
+
+template <typename T> struct ToStdComplex {
+  using Type = T;
+  using AType = Type;
+};
+template <> struct ToStdComplex<float _Complex> {
+  using Type = std::complex<float>;
+  using AType = const Type &;
+};
+template <> struct ToStdComplex<double _Complex> {
+  using Type = std::complex<double>;
+  using AType = const Type &;
+};
+
+template <typename F, F func> struct CComplexFunc {};
+template <typename R, typename... A, FuncPointer<R, A...> func>
+struct CComplexFunc<FuncPointer<R, A...>, func> {
+  static typename ToStdComplex<R>::Type wrapper(
+      typename ToStdComplex<A>::AType... args) {
+    R res{func(*reinterpret_cast<const A *>(&args)...)};
+    return *reinterpret_cast<typename ToStdComplex<R>::Type *>(&res);
+  }
+};
+#define C_COMPLEX_FUNC(func) CComplexFunc<decltype(&func), &func>::wrapper
+
+template <>
+struct HostRuntimeLibrary<std::complex<float>, LibraryVersion::Libm> {
+  COMPLEX_SIGNATURES(float)
+  static constexpr HostRuntimeFunction table[]{
+      FolderFactory<F, C_COMPLEX_FUNC(cacosf)>::Create("acos"),
+      FolderFactory<F, C_COMPLEX_FUNC(cacoshf)>::Create("acosh"),
+      FolderFactory<F, C_COMPLEX_FUNC(casinf)>::Create("asin"),
+      FolderFactory<F, C_COMPLEX_FUNC(casinhf)>::Create("asinh"),
+      FolderFactory<F, C_COMPLEX_FUNC(catanf)>::Create("atan"),
+      FolderFactory<F, C_COMPLEX_FUNC(catanhf)>::Create("atanh"),
+      FolderFactory<F, C_COMPLEX_FUNC(ccosf)>::Create("cos"),
+      FolderFactory<F, C_COMPLEX_FUNC(ccoshf)>::Create("cosh"),
+      FolderFactory<F, C_COMPLEX_FUNC(cexpf)>::Create("exp"),
+      FolderFactory<F, C_COMPLEX_FUNC(clogf)>::Create("log"),
+      FolderFactory<F2, C_COMPLEX_FUNC(cpowf)>::Create("pow"),
+      FolderFactory<F, C_COMPLEX_FUNC(csinf)>::Create("sin"),
+      FolderFactory<F, C_COMPLEX_FUNC(csinhf)>::Create("sinh"),
+      FolderFactory<F, C_COMPLEX_FUNC(csqrtf)>::Create("sqrt"),
+      FolderFactory<F, C_COMPLEX_FUNC(ctanf)>::Create("tan"),
+      FolderFactory<F, C_COMPLEX_FUNC(ctanhf)>::Create("tanh"),
+  };
+  static constexpr HostRuntimeMap map{table};
+  static_assert(map.Verify(), "map must be sorted");
+};
+template <>
+struct HostRuntimeLibrary<std::complex<double>, LibraryVersion::Libm> {
+  COMPLEX_SIGNATURES(double)
+  static constexpr HostRuntimeFunction table[]{
+      FolderFactory<F, C_COMPLEX_FUNC(cacos)>::Create("acos"),
+      FolderFactory<F, C_COMPLEX_FUNC(cacosh)>::Create("acosh"),
+      FolderFactory<F, C_COMPLEX_FUNC(casin)>::Create("asin"),
+      FolderFactory<F, C_COMPLEX_FUNC(casinh)>::Create("asinh"),
+      FolderFactory<F, C_COMPLEX_FUNC(catan)>::Create("atan"),
+      FolderFactory<F, C_COMPLEX_FUNC(catanh)>::Create("atanh"),
+      FolderFactory<F, C_COMPLEX_FUNC(ccos)>::Create("cos"),
+      FolderFactory<F, C_COMPLEX_FUNC(ccosh)>::Create("cosh"),
+      FolderFactory<F, C_COMPLEX_FUNC(cexp)>::Create("exp"),
+      FolderFactory<F, C_COMPLEX_FUNC(__clog)>::Create("log"),
+      FolderFactory<F2, C_COMPLEX_FUNC(cpow)>::Create("pow"),
+      FolderFactory<F, C_COMPLEX_FUNC(csin)>::Create("sin"),
+      FolderFactory<F, C_COMPLEX_FUNC(csinh)>::Create("sinh"),
+      FolderFactory<F, C_COMPLEX_FUNC(csqrt)>::Create("sqrt"),
+      FolderFactory<F, C_COMPLEX_FUNC(ctan)>::Create("tan"),
+      FolderFactory<F, C_COMPLEX_FUNC(ctanh)>::Create("tanh"),
+  };
+  static constexpr HostRuntimeMap map{table};
+  static_assert(map.Verify(), "map must be sorted");
+};
+#endif // _AIX
+
 // Note regarding cmath:
 //  - cmath does not have modulo and erfc_scaled equivalent
 //  - C++17 defined standard Bessel math functions std::cyl_bessel_j


        


More information about the flang-commits mailing list