[flang-commits] [flang] a8d4335 - [flang][AIX] Handle more trig functions with complex argument to have consistent results in folding (#124203)
via flang-commits
flang-commits at lists.llvm.org
Wed Jan 29 17:26:43 PST 2025
Author: Kelvin Li
Date: 2025-01-29T20:26:39-05:00
New Revision: a8d4335ee08b18068eb913528b963590f510d0b4
URL: https://github.com/llvm/llvm-project/commit/a8d4335ee08b18068eb913528b963590f510d0b4
DIFF: https://github.com/llvm/llvm-project/commit/a8d4335ee08b18068eb913528b963590f510d0b4.diff
LOG: [flang][AIX] Handle more trig functions with complex argument to have consistent results in folding (#124203)
This patch extends 71d4f34 to all trig functions that take complex
arguments. On AIX, the `libm` routines are called in compile time
folding instead of the STL routines.
Added:
Modified:
flang/lib/Evaluate/intrinsics-library.cpp
Removed:
################################################################################
diff --git a/flang/lib/Evaluate/intrinsics-library.cpp b/flang/lib/Evaluate/intrinsics-library.cpp
index c1b270f518c0e0..d2c1be65dca448 100644
--- a/flang/lib/Evaluate/intrinsics-library.cpp
+++ b/flang/lib/Evaluate/intrinsics-library.cpp
@@ -260,6 +260,16 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
static_assert(map.Verify(), "map must be sorted");
};
+#define COMPLEX_SIGNATURES(HOST_T) \
+ using F = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &>; \
+ using F2 = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &, \
+ const std::complex<HOST_T> &>; \
+ using F2A = FuncPointer<std::complex<HOST_T>, const HOST_T &, \
+ const std::complex<HOST_T> &>; \
+ using F2B = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &, \
+ const HOST_T &>;
+
+#ifndef _AIX
// Helpers to map complex std::pow whose resolution in F2{std::pow} is
// ambiguous as of clang++ 20.
template <typename HostT>
@@ -267,98 +277,24 @@ static std::complex<HostT> StdPowF2(
const std::complex<HostT> &x, const std::complex<HostT> &y) {
return std::pow(x, y);
}
+
template <typename HostT>
static std::complex<HostT> StdPowF2A(
const HostT &x, const std::complex<HostT> &y) {
return std::pow(x, y);
}
+
template <typename HostT>
static std::complex<HostT> StdPowF2B(
const std::complex<HostT> &x, const HostT &y) {
return std::pow(x, y);
}
-#ifdef _AIX
-#ifdef __clang_major__
-#pragma clang diagnostic ignored "-Wc99-extensions"
-#endif
-
-extern "C" {
-float _Complex cacosf(float _Complex);
-double _Complex cacos(double _Complex);
-float _Complex csqrtf(float _Complex);
-double _Complex csqrt(double _Complex);
-}
-
-enum CRI { Real, Imag };
-template <typename TR, typename TA> static TR &reIm(TA &x, CRI n) {
- return reinterpret_cast<TR(&)[2]>(x)[n];
-}
-template <typename TR, typename T> static TR CppToC(const std::complex<T> &x) {
- TR r;
- reIm<T, TR>(r, CRI::Real) = x.real();
- reIm<T, TR>(r, CRI::Imag) = x.imag();
- return r;
-}
-template <typename T, typename TA> static std::complex<T> CToCpp(const TA &x) {
- TA &z{const_cast<TA &>(x)};
- return std::complex<T>(reIm<T, TA>(z, CRI::Real), reIm<T, TA>(z, CRI::Imag));
-}
-#endif
-
-template <typename HostT>
-static std::complex<HostT> CSqrt(const std::complex<HostT> &x) {
- std::complex<HostT> res;
-#ifdef _AIX
- // On AIX, the implementation of csqrt[f] and std::sqrt is different,
- // use csqrt[f] in folding.
- if constexpr (std::is_same_v<HostT, float>) {
- float _Complex r{csqrtf(CppToC<float _Complex, float>(x))};
- res = CToCpp<float, float _Complex>(r);
- } else if constexpr (std::is_same_v<HostT, double>) {
- double _Complex r{csqrt(CppToC<double _Complex, double>(x))};
- res = CToCpp<double, double _Complex>(r);
- } else {
- DIE("bad complex component type");
- }
-#else
- res = std::sqrt(x);
-#endif
- return res;
-}
-
-template <typename HostT>
-static std::complex<HostT> CAcos(const std::complex<HostT> &x) {
- std::complex<HostT> res;
-#ifdef _AIX
- // On AIX, the implementation of cacos[f] and std::acos is different,
- // use cacos[f] in folding.
- if constexpr (std::is_same_v<HostT, float>) {
- float _Complex r{cacosf(CppToC<float _Complex, float>(x))};
- res = CToCpp<float, float _Complex>(r);
- } else if constexpr (std::is_same_v<HostT, double>) {
- double _Complex r{cacos(CppToC<double _Complex, double>(x))};
- res = CToCpp<double, double _Complex>(r);
- } else {
- DIE("bad complex component type");
- }
-#else
- res = std::acos(x);
-#endif
- return res;
-}
-
template <typename HostT>
struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
- using F = FuncPointer<std::complex<HostT>, const std::complex<HostT> &>;
- using F2 = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
- const std::complex<HostT> &>;
- using F2A = FuncPointer<std::complex<HostT>, const HostT &,
- const std::complex<HostT> &>;
- using F2B = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
- const HostT &>;
+ COMPLEX_SIGNATURES(HostT)
static constexpr HostRuntimeFunction table[]{
- FolderFactory<F, F{CAcos}>::Create("acos"),
+ FolderFactory<F, F{std::acos}>::Create("acos"),
FolderFactory<F, F{std::acosh}>::Create("acosh"),
FolderFactory<F, F{std::asin}>::Create("asin"),
FolderFactory<F, F{std::asinh}>::Create("asinh"),
@@ -373,13 +309,129 @@ struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
FolderFactory<F2B, F2B{StdPowF2B}>::Create("pow"),
FolderFactory<F, F{std::sin}>::Create("sin"),
FolderFactory<F, F{std::sinh}>::Create("sinh"),
- FolderFactory<F, F{CSqrt}>::Create("sqrt"),
+ FolderFactory<F, F{std::sqrt}>::Create("sqrt"),
FolderFactory<F, F{std::tan}>::Create("tan"),
FolderFactory<F, F{std::tanh}>::Create("tanh"),
};
static constexpr HostRuntimeMap map{table};
static_assert(map.Verify(), "map must be sorted");
};
+#else
+// On AIX, call libm routines to preserve consistent value between
+// runtime and compile time evaluation.
+#ifdef __clang_major__
+#pragma clang diagnostic ignored "-Wc99-extensions"
+#endif
+
+extern "C" {
+float _Complex cacosf(float _Complex);
+double _Complex cacos(double _Complex);
+float _Complex cacoshf(float _Complex);
+double _Complex cacosh(double _Complex);
+float _Complex casinf(float _Complex);
+double _Complex casin(double _Complex);
+float _Complex casinhf(float _Complex);
+double _Complex casinh(double _Complex);
+float _Complex catanf(float _Complex);
+double _Complex catan(double _Complex);
+float _Complex catanhf(float _Complex);
+double _Complex catanh(double _Complex);
+float _Complex ccosf(float _Complex);
+double _Complex ccos(double _Complex);
+float _Complex ccoshf(float _Complex);
+double _Complex ccosh(double _Complex);
+float _Complex cexpf(float _Complex);
+double _Complex cexp(double _Complex);
+float _Complex clogf(float _Complex);
+double _Complex __clog(double _Complex);
+float _Complex cpowf(float _Complex, float _Complex);
+double _Complex cpow(double _Complex, double _Complex);
+float _Complex csinf(float _Complex);
+double _Complex csin(double _Complex);
+float _Complex csinhf(float _Complex);
+double _Complex csinh(double _Complex);
+float _Complex csqrtf(float _Complex);
+double _Complex csqrt(double _Complex);
+float _Complex ctanf(float _Complex);
+double _Complex ctan(double _Complex);
+float _Complex ctanhf(float _Complex);
+double _Complex ctanh(double _Complex);
+}
+
+template <typename T> struct ToStdComplex {
+ using Type = T;
+ using AType = Type;
+};
+template <> struct ToStdComplex<float _Complex> {
+ using Type = std::complex<float>;
+ using AType = const Type &;
+};
+template <> struct ToStdComplex<double _Complex> {
+ using Type = std::complex<double>;
+ using AType = const Type &;
+};
+
+template <typename F, F func> struct CComplexFunc {};
+template <typename R, typename... A, FuncPointer<R, A...> func>
+struct CComplexFunc<FuncPointer<R, A...>, func> {
+ static typename ToStdComplex<R>::Type wrapper(
+ typename ToStdComplex<A>::AType... args) {
+ R res{func(*reinterpret_cast<const A *>(&args)...)};
+ return *reinterpret_cast<typename ToStdComplex<R>::Type *>(&res);
+ }
+};
+#define C_COMPLEX_FUNC(func) CComplexFunc<decltype(&func), &func>::wrapper
+
+template <>
+struct HostRuntimeLibrary<std::complex<float>, LibraryVersion::Libm> {
+ COMPLEX_SIGNATURES(float)
+ static constexpr HostRuntimeFunction table[]{
+ FolderFactory<F, C_COMPLEX_FUNC(cacosf)>::Create("acos"),
+ FolderFactory<F, C_COMPLEX_FUNC(cacoshf)>::Create("acosh"),
+ FolderFactory<F, C_COMPLEX_FUNC(casinf)>::Create("asin"),
+ FolderFactory<F, C_COMPLEX_FUNC(casinhf)>::Create("asinh"),
+ FolderFactory<F, C_COMPLEX_FUNC(catanf)>::Create("atan"),
+ FolderFactory<F, C_COMPLEX_FUNC(catanhf)>::Create("atanh"),
+ FolderFactory<F, C_COMPLEX_FUNC(ccosf)>::Create("cos"),
+ FolderFactory<F, C_COMPLEX_FUNC(ccoshf)>::Create("cosh"),
+ FolderFactory<F, C_COMPLEX_FUNC(cexpf)>::Create("exp"),
+ FolderFactory<F, C_COMPLEX_FUNC(clogf)>::Create("log"),
+ FolderFactory<F2, C_COMPLEX_FUNC(cpowf)>::Create("pow"),
+ FolderFactory<F, C_COMPLEX_FUNC(csinf)>::Create("sin"),
+ FolderFactory<F, C_COMPLEX_FUNC(csinhf)>::Create("sinh"),
+ FolderFactory<F, C_COMPLEX_FUNC(csqrtf)>::Create("sqrt"),
+ FolderFactory<F, C_COMPLEX_FUNC(ctanf)>::Create("tan"),
+ FolderFactory<F, C_COMPLEX_FUNC(ctanhf)>::Create("tanh"),
+ };
+ static constexpr HostRuntimeMap map{table};
+ static_assert(map.Verify(), "map must be sorted");
+};
+template <>
+struct HostRuntimeLibrary<std::complex<double>, LibraryVersion::Libm> {
+ COMPLEX_SIGNATURES(double)
+ static constexpr HostRuntimeFunction table[]{
+ FolderFactory<F, C_COMPLEX_FUNC(cacos)>::Create("acos"),
+ FolderFactory<F, C_COMPLEX_FUNC(cacosh)>::Create("acosh"),
+ FolderFactory<F, C_COMPLEX_FUNC(casin)>::Create("asin"),
+ FolderFactory<F, C_COMPLEX_FUNC(casinh)>::Create("asinh"),
+ FolderFactory<F, C_COMPLEX_FUNC(catan)>::Create("atan"),
+ FolderFactory<F, C_COMPLEX_FUNC(catanh)>::Create("atanh"),
+ FolderFactory<F, C_COMPLEX_FUNC(ccos)>::Create("cos"),
+ FolderFactory<F, C_COMPLEX_FUNC(ccosh)>::Create("cosh"),
+ FolderFactory<F, C_COMPLEX_FUNC(cexp)>::Create("exp"),
+ FolderFactory<F, C_COMPLEX_FUNC(__clog)>::Create("log"),
+ FolderFactory<F2, C_COMPLEX_FUNC(cpow)>::Create("pow"),
+ FolderFactory<F, C_COMPLEX_FUNC(csin)>::Create("sin"),
+ FolderFactory<F, C_COMPLEX_FUNC(csinh)>::Create("sinh"),
+ FolderFactory<F, C_COMPLEX_FUNC(csqrt)>::Create("sqrt"),
+ FolderFactory<F, C_COMPLEX_FUNC(ctan)>::Create("tan"),
+ FolderFactory<F, C_COMPLEX_FUNC(ctanh)>::Create("tanh"),
+ };
+ static constexpr HostRuntimeMap map{table};
+ static_assert(map.Verify(), "map must be sorted");
+};
+#endif // _AIX
+
// Note regarding cmath:
// - cmath does not have modulo and erfc_scaled equivalent
// - C++17 defined standard Bessel math functions std::cyl_bessel_j
More information about the flang-commits
mailing list