[flang-commits] [flang] [flang][AIX] Handle more trig functions with complex argument to have consistent results in folding (PR #124203)
Kelvin Li via flang-commits
flang-commits at lists.llvm.org
Thu Jan 23 14:51:00 PST 2025
https://github.com/kkwli created https://github.com/llvm/llvm-project/pull/124203
This patch extends https://github.com/llvm/llvm-project/commit/71d4f343f52756ca086d02151662e68633a0db52 to cover all trig functions that accept arguments of complex type. On AIX, the `libm` routines are called during compile-time folding instead of the C++ standard library (`std::`) routines, so that folded results are consistent with those computed at run time.
>From ec087b0d95c4ce024ab1b594c2593c555bc14e36 Mon Sep 17 00:00:00 2001
From: Kelvin Li <kli at ca.ibm.com>
Date: Fri, 3 Jan 2025 18:15:26 -0500
Subject: [PATCH 1/2] [flang] Handle more trig functions with complex argument
to have consistent results in folding
---
flang/lib/Evaluate/intrinsics-library.cpp | 178 ++++++++++++++++------
1 file changed, 130 insertions(+), 48 deletions(-)
diff --git a/flang/lib/Evaluate/intrinsics-library.cpp b/flang/lib/Evaluate/intrinsics-library.cpp
index c1b270f518c0e0..c47004b289e00c 100644
--- a/flang/lib/Evaluate/intrinsics-library.cpp
+++ b/flang/lib/Evaluate/intrinsics-library.cpp
@@ -278,6 +278,24 @@ static std::complex<HostT> StdPowF2B(
return std::pow(x, y);
}
+enum trigFunc {
+ Cacos,
+ Cacosh,
+ Casin,
+ Casinh,
+ Catan,
+ Catanh,
+ Ccos,
+ Ccosh,
+ Cexp,
+ Clog,
+ Csin,
+ Csinh,
+ Csqrt,
+ Ctan,
+ Ctanh
+};
+
#ifdef _AIX
#ifdef __clang_major__
#pragma clang diagnostic ignored "-Wc99-extensions"
@@ -286,8 +304,34 @@ static std::complex<HostT> StdPowF2B(
extern "C" {
float _Complex cacosf(float _Complex);
double _Complex cacos(double _Complex);
+float _Complex cacoshf(float _Complex);
+double _Complex cacosh(double _Complex);
+float _Complex casinf(float _Complex);
+double _Complex casin(double _Complex);
+float _Complex casinhf(float _Complex);
+double _Complex casinh(double _Complex);
+float _Complex catanf(float _Complex);
+double _Complex catan(double _Complex);
+float _Complex catanhf(float _Complex);
+double _Complex catanh(double _Complex);
+float _Complex ccosf(float _Complex);
+double _Complex ccos(double _Complex);
+float _Complex ccoshf(float _Complex);
+double _Complex ccosh(double _Complex);
+float _Complex cexpf(float _Complex);
+double _Complex cexp(double _Complex);
+float _Complex clogf(float _Complex);
+double _Complex __clog(double _Complex);
+float _Complex csinf(float _Complex);
+double _Complex csin(double _Complex);
+float _Complex csinhf(float _Complex);
+double _Complex csinh(double _Complex);
float _Complex csqrtf(float _Complex);
double _Complex csqrt(double _Complex);
+float _Complex ctanf(float _Complex);
+double _Complex ctan(double _Complex);
+float _Complex ctanhf(float _Complex);
+double _Complex ctanh(double _Complex);
}
enum CRI { Real, Imag };
@@ -304,49 +348,87 @@ template <typename T, typename TA> static std::complex<T> CToCpp(const TA &x) {
TA &z{const_cast<TA &>(x)};
return std::complex<T>(reIm<T, TA>(z, CRI::Real), reIm<T, TA>(z, CRI::Imag));
}
-#endif
-template <typename HostT>
-static std::complex<HostT> CSqrt(const std::complex<HostT> &x) {
- std::complex<HostT> res;
-#ifdef _AIX
- // On AIX, the implementation of csqrt[f] and std::sqrt is different,
- // use csqrt[f] in folding.
+using FTypeCmplxFlt = _Complex float (*)(_Complex float);
+using FTypeCmplxDble = _Complex double (*)(_Complex double);
+template <typename T>
+using FTypeStdCmplx = std::complex<T> (*)(const std::complex<T>&);
+
+std::map<trigFunc, std::tuple<FTypeCmplxFlt, FTypeCmplxDble>> mapLibmTrigFunc{
+ {Cacos, {&cacosf, &cacos}}, {Cacosh, {&cacoshf, &cacosh}},
+ {Casin, {&casinf, &casin}}, {Casinh, {&casinhf, &casinh}},
+ {Catan, {&catanf, &catan}}, {Catanh, {&catanhf, &catanh}},
+ {Ccos, {&ccosf, &ccos}}, {Ccosh, {&ccoshf, &ccosh}},
+ {Cexp, {&cexpf, &cexp}}, {Clog, {&clogf, &__clog}},
+ {Csin, {&csinf, &csin}}, {Csinh, {&csinhf, &csinh}},
+ {Csqrt, {&csqrtf, &csqrt}}, {Ctan, {&ctanf, &ctan}},
+ {Ctanh, {&ctanhf, &ctanh}}};
+
+template <trigFunc TF, typename HostT>
+std::complex<HostT> LibmTrigFunc(const std::complex<HostT> &x) {
if constexpr (std::is_same_v<HostT, float>) {
- float _Complex r{csqrtf(CppToC<float _Complex, float>(x))};
- res = CToCpp<float, float _Complex>(r);
+ float _Complex r{
+ std::get<FTypeCmplxFlt>(mapLibmTrigFunc[TF])(CppToC<float _Complex, float>(x))};
+ return CToCpp<float, float _Complex>(r);
} else if constexpr (std::is_same_v<HostT, double>) {
- double _Complex r{csqrt(CppToC<double _Complex, double>(x))};
- res = CToCpp<double, double _Complex>(r);
- } else {
- DIE("bad complex component type");
+ double _Complex r{
+ std::get<FTypeCmplxDble>(mapLibmTrigFunc[TF])(CppToC<double _Complex, double>(x))};
+ return CToCpp<double, double _Complex>(r);
}
-#else
- res = std::sqrt(x);
+ DIE("bad complex component type");
+}
#endif
- return res;
+
+template <trigFunc TF, typename HostT>
+std::complex<HostT> StdTrigFunc(const std::complex<HostT> &x) {
+ if constexpr (TF == Cacos) {
+ return std::acos(x);
+ } else if constexpr (TF == Cacosh) {
+ return std::acosh(x);
+ } else if constexpr (TF == Casin) {
+ return std::asin(x);
+ } else if constexpr (TF == Casinh) {
+ return std::asinh(x);
+ } else if constexpr (TF == Catan) {
+ return std::atan(x);
+ } else if constexpr (TF == Catanh) {
+ return std::atanh(x);
+ } else if constexpr (TF == Ccos) {
+ return std::cos(x);
+ } else if constexpr (TF == Ccosh) {
+ return std::cosh(x);
+ } else if constexpr (TF == Cexp) {
+ return std::exp(x);
+ } else if constexpr (TF == Clog) {
+ return std::log(x);
+ } else if constexpr (TF == Csin) {
+ return std::sin(x);
+ } else if constexpr (TF == Csinh) {
+ return std::sinh(x);
+ } else if constexpr (TF == Csqrt) {
+ return std::sqrt(x);
+ } else if constexpr (TF == Ctan) {
+ return std::tan(x);
+ } else if constexpr (TF == Ctanh) {
+ return std::tanh(x);
+ }
+ DIE("unknown function");
}
-template <typename HostT>
-static std::complex<HostT> CAcos(const std::complex<HostT> &x) {
- std::complex<HostT> res;
+template <trigFunc TF> struct X {
+ template <typename HostT>
+ static std::complex<HostT> f(const std::complex<HostT> &x) {
+ std::complex<HostT> res;
#ifdef _AIX
- // On AIX, the implementation of cacos[f] and std::acos is different,
- // use cacos[f] in folding.
- if constexpr (std::is_same_v<HostT, float>) {
- float _Complex r{cacosf(CppToC<float _Complex, float>(x))};
- res = CToCpp<float, float _Complex>(r);
- } else if constexpr (std::is_same_v<HostT, double>) {
- double _Complex r{cacos(CppToC<double _Complex, double>(x))};
- res = CToCpp<double, double _Complex>(r);
- } else {
- DIE("bad complex component type");
- }
+ // On AIX, the implementation in libm is different from that of std::
+ // routines, use the libm routines here in folding for consistent results.
+ res = LibmTrigFunc<TF>(x);
#else
- res = std::acos(x);
+ res = StdTrigFunc<TF, HostT>(x);
#endif
- return res;
-}
+ return res;
+ }
+};
template <typename HostT>
struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
@@ -358,24 +440,24 @@ struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
using F2B = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
const HostT &>;
static constexpr HostRuntimeFunction table[]{
- FolderFactory<F, F{CAcos}>::Create("acos"),
- FolderFactory<F, F{std::acosh}>::Create("acosh"),
- FolderFactory<F, F{std::asin}>::Create("asin"),
- FolderFactory<F, F{std::asinh}>::Create("asinh"),
- FolderFactory<F, F{std::atan}>::Create("atan"),
- FolderFactory<F, F{std::atanh}>::Create("atanh"),
- FolderFactory<F, F{std::cos}>::Create("cos"),
- FolderFactory<F, F{std::cosh}>::Create("cosh"),
- FolderFactory<F, F{std::exp}>::Create("exp"),
- FolderFactory<F, F{std::log}>::Create("log"),
+ FolderFactory<F, F{X<Cacos>::f}>::Create("acos"),
+ FolderFactory<F, F{X<Cacosh>::f}>::Create("acosh"),
+ FolderFactory<F, F{X<Casin>::f}>::Create("asin"),
+ FolderFactory<F, F{X<Casinh>::f}>::Create("asinh"),
+ FolderFactory<F, F{X<Catan>::f}>::Create("atan"),
+ FolderFactory<F, F{X<Catanh>::f}>::Create("atanh"),
+ FolderFactory<F, F{X<Ccos>::f}>::Create("cos"),
+ FolderFactory<F, F{X<Ccosh>::f}>::Create("cosh"),
+ FolderFactory<F, F{X<Cexp>::f}>::Create("exp"),
+ FolderFactory<F, F{X<Clog>::f}>::Create("log"),
FolderFactory<F2, F2{StdPowF2}>::Create("pow"),
FolderFactory<F2A, F2A{StdPowF2A}>::Create("pow"),
FolderFactory<F2B, F2B{StdPowF2B}>::Create("pow"),
- FolderFactory<F, F{std::sin}>::Create("sin"),
- FolderFactory<F, F{std::sinh}>::Create("sinh"),
- FolderFactory<F, F{CSqrt}>::Create("sqrt"),
- FolderFactory<F, F{std::tan}>::Create("tan"),
- FolderFactory<F, F{std::tanh}>::Create("tanh"),
+ FolderFactory<F, F{X<Csin>::f}>::Create("sin"),
+ FolderFactory<F, F{X<Csinh>::f}>::Create("sinh"),
+ FolderFactory<F, F{X<Csqrt>::f}>::Create("sqrt"),
+ FolderFactory<F, F{X<Ctan>::f}>::Create("tan"),
+ FolderFactory<F, F{X<Ctanh>::f}>::Create("tanh"),
};
static constexpr HostRuntimeMap map{table};
static_assert(map.Verify(), "map must be sorted");
>From e032bfd3b62f1f1c241fc0822d1e2185b1f34d2d Mon Sep 17 00:00:00 2001
From: Kelvin Li <kli at ca.ibm.com>
Date: Thu, 23 Jan 2025 17:39:18 -0500
Subject: [PATCH 2/2] add pow
---
flang/lib/Evaluate/intrinsics-library.cpp | 80 ++++++++++++++++++-----
1 file changed, 62 insertions(+), 18 deletions(-)
diff --git a/flang/lib/Evaluate/intrinsics-library.cpp b/flang/lib/Evaluate/intrinsics-library.cpp
index c47004b289e00c..60bb2785c725fd 100644
--- a/flang/lib/Evaluate/intrinsics-library.cpp
+++ b/flang/lib/Evaluate/intrinsics-library.cpp
@@ -260,24 +260,6 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
static_assert(map.Verify(), "map must be sorted");
};
-// Helpers to map complex std::pow whose resolution in F2{std::pow} is
-// ambiguous as of clang++ 20.
-template <typename HostT>
-static std::complex<HostT> StdPowF2(
- const std::complex<HostT> &x, const std::complex<HostT> &y) {
- return std::pow(x, y);
-}
-template <typename HostT>
-static std::complex<HostT> StdPowF2A(
- const HostT &x, const std::complex<HostT> &y) {
- return std::pow(x, y);
-}
-template <typename HostT>
-static std::complex<HostT> StdPowF2B(
- const std::complex<HostT> &x, const HostT &y) {
- return std::pow(x, y);
-}
-
enum trigFunc {
Cacos,
Cacosh,
@@ -322,6 +304,8 @@ float _Complex cexpf(float _Complex);
double _Complex cexp(double _Complex);
float _Complex clogf(float _Complex);
double _Complex __clog(double _Complex);
+float _Complex cpowf(float _Complex, float _Complex);
+double _Complex cpow(double _Complex, double _Complex);
float _Complex csinf(float _Complex);
double _Complex csin(double _Complex);
float _Complex csinhf(float _Complex);
@@ -430,6 +414,66 @@ template <trigFunc TF> struct X {
}
};
+// Helpers to map complex std::pow whose resolution in F2{std::pow} is
+// ambiguous as of clang++ 20.
+template <typename HostT>
+static std::complex<HostT> StdPowF2(const std::complex<HostT> &x,
+ const std::complex<HostT> &y) {
+#ifdef _AIX
+ if constexpr (std::is_same_v<HostT, float>) {
+ float _Complex r{cpowf(CppToC<float _Complex, float>(x),
+ CppToC<float _Complex, float>(y))};
+ return CToCpp<float, float _Complex>(r);
+ } else if constexpr (std::is_same_v<HostT, double>) {
+ double _Complex r{cpow(CppToC<double _Complex, double>(x),
+ CppToC<double _Complex, double>(y))};
+ return CToCpp<double, double _Complex>(r);
+ }
+#else
+ return std::pow(x, y);
+#endif
+}
+
+template <typename HostT>
+static std::complex<HostT> StdPowF2A(const HostT &x,
+ const std::complex<HostT> &y) {
+#ifdef _AIX
+ constexpr HostT zero{0.0};
+ std::complex<HostT> z(x, zero);
+ if constexpr (std::is_same_v<HostT, float>) {
+ float _Complex r{cpowf(CppToC<float _Complex, float>(z),
+ CppToC<float _Complex, float>(y))};
+ return CToCpp<float, float _Complex>(r);
+ } else if constexpr (std::is_same_v<HostT, double>) {
+ double _Complex r{cpow(CppToC<double _Complex, double>(z),
+ CppToC<double _Complex, double>(y))};
+ return CToCpp<double, double _Complex>(r);
+ }
+#else
+ return std::pow(x, y);
+#endif
+}
+
+template <typename HostT>
+static std::complex<HostT> StdPowF2B(const std::complex<HostT> &x,
+ const HostT &y) {
+#ifdef _AIX
+ constexpr HostT zero{0.0};
+ std::complex<HostT> z(y, zero);
+ if constexpr (std::is_same_v<HostT, float>) {
+ float _Complex r{cpowf(CppToC<float _Complex, float>(x),
+ CppToC<float _Complex, float>(z))};
+ return CToCpp<float, float _Complex>(r);
+ } else if constexpr (std::is_same_v<HostT, double>) {
+ double _Complex r{cpow(CppToC<double _Complex, double>(x),
+ CppToC<double _Complex, double>(z))};
+ return CToCpp<double, double _Complex>(r);
+ }
+#else
+ return std::pow(x, y);
+#endif
+}
+
template <typename HostT>
struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
using F = FuncPointer<std::complex<HostT>, const std::complex<HostT> &>;
More information about the flang-commits
mailing list