[libcxx-commits] [libcxx] [libc++] Constrain additional overloads of `pow` for `complex` harder (PR #110235)

A. Jiang via libcxx-commits libcxx-commits at lists.llvm.org
Sat Oct 26 23:51:33 PDT 2024


https://github.com/frederick-vs-ja updated https://github.com/llvm/llvm-project/pull/110235

>From b3c2eeaec8b67dfbb05840b53cccfcb562e38313 Mon Sep 17 00:00:00 2001
From: "A. Jiang" <de34 at live.cn>
Date: Fri, 27 Sep 2024 18:04:36 +0800
Subject: [PATCH 1/3] [libc++] Constrain additional overloads of `pow` for
 `complex` harder

---
 libcxx/include/complex                        |   6 +-
 .../complex.number/cmplx.over.pow.pass.cpp    | 106 ++++++++++++++++++
 2 files changed, 109 insertions(+), 3 deletions(-)
 create mode 100644 libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp

diff --git a/libcxx/include/complex b/libcxx/include/complex
index 4030d96b003d56..15e42800fbfa0a 100644
--- a/libcxx/include/complex
+++ b/libcxx/include/complex
@@ -1097,20 +1097,20 @@ inline _LIBCPP_HIDE_FROM_ABI complex<_Tp> pow(const complex<_Tp>& __x, const com
   return std::exp(__y * std::log(__x));
 }
 
-template <class _Tp, class _Up>
+template <class _Tp, class _Up, __enable_if_t<is_floating_point<_Tp>::value && is_floating_point<_Up>::value, int> = 0>
 inline _LIBCPP_HIDE_FROM_ABI complex<typename __promote<_Tp, _Up>::type>
 pow(const complex<_Tp>& __x, const complex<_Up>& __y) {
   typedef complex<typename __promote<_Tp, _Up>::type> result_type;
   return std::pow(result_type(__x), result_type(__y));
 }
 
-template <class _Tp, class _Up, __enable_if_t<is_arithmetic<_Up>::value, int> = 0>
+template <class _Tp, class _Up, __enable_if_t<is_floating_point<_Tp>::value && is_arithmetic<_Up>::value, int> = 0>
 inline _LIBCPP_HIDE_FROM_ABI complex<typename __promote<_Tp, _Up>::type> pow(const complex<_Tp>& __x, const _Up& __y) {
   typedef complex<typename __promote<_Tp, _Up>::type> result_type;
   return std::pow(result_type(__x), result_type(__y));
 }
 
-template <class _Tp, class _Up, __enable_if_t<is_arithmetic<_Tp>::value, int> = 0>
+template <class _Tp, class _Up, __enable_if_t<is_arithmetic<_Tp>::value && is_floating_point<_Up>::value, int> = 0>
 inline _LIBCPP_HIDE_FROM_ABI complex<typename __promote<_Tp, _Up>::type> pow(const _Tp& __x, const complex<_Up>& __y) {
   typedef complex<typename __promote<_Tp, _Up>::type> result_type;
   return std::pow(result_type(__x), result_type(__y));
diff --git a/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp b/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp
new file mode 100644
index 00000000000000..64e679fed7435c
--- /dev/null
+++ b/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp
@@ -0,0 +1,106 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// <complex>
+
+//  template<class T, class U> complex<__promote<T, U>::type> pow(const complex<T>&, const U&);
+//  template<class T, class U> complex<__promote<T, U>::type> pow(const complex<T>&, const complex<U>&);
+//  template<class T, class U> complex<__promote<T, U>::type> pow(const T&, const complex<U>&);
+
+// Test that these additional overloads are free from catching std::complex<non-floating-point>,
+// which is expected by several 3rd party libraries, see https://github.com/llvm/llvm-project/issues/109858.
+
+#include <cassert>
+#include <cmath>
+#include <complex>
+#include <type_traits>
+
+#include "test_macros.h"
+
+namespace usr {
+struct usr_tag {};
+
+template <class T, class U>
+TEST_CONSTEXPR
+    typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
+                                (std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
+                            int>::type
+    pow(const T&, const std::complex<U>&) {
+  return std::is_same<T, usr_tag>::value ? 0 : 1;
+}
+
+template <class T, class U>
+TEST_CONSTEXPR
+    typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
+                                (std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
+                            int>::type
+    pow(const std::complex<T>&, const U&) {
+  return std::is_same<U, usr_tag>::value ? 2 : 3;
+}
+
+template <class T, class U>
+TEST_CONSTEXPR
+    typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
+                                (std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
+                            int>::type
+    pow(const std::complex<T>&, const std::complex<U>&) {
+  return std::is_same<T, usr_tag>::value ? 4 : 5;
+}
+} // namespace usr
+
+int main(int, char**) {
+  using std::pow;
+  using usr::pow;
+
+  TEST_CONSTEXPR usr::usr_tag tag;
+  TEST_CONSTEXPR_CXX14 const std::complex<usr::usr_tag> ctag;
+
+  assert(pow(tag, std::complex<float>(1.0f)) == 0);
+  assert(pow(std::complex<float>(1.0f), tag) == 2);
+  assert(pow(tag, std::complex<double>(1.0)) == 0);
+  assert(pow(std::complex<double>(1.0), tag) == 2);
+  assert(pow(tag, std::complex<long double>(1.0l)) == 0);
+  assert(pow(std::complex<long double>(1.0l), tag) == 2);
+
+  assert(pow(1.0f, ctag) == 1);
+  assert(pow(ctag, 1.0f) == 3);
+  assert(pow(1.0, ctag) == 1);
+  assert(pow(ctag, 1.0) == 3);
+  assert(pow(1.0l, ctag) == 1);
+  assert(pow(ctag, 1.0l) == 3);
+
+  assert(pow(ctag, std::complex<float>(1.0f)) == 4);
+  assert(pow(std::complex<float>(1.0f), ctag) == 5);
+  assert(pow(ctag, std::complex<double>(1.0)) == 4);
+  assert(pow(std::complex<double>(1.0), ctag) == 5);
+  assert(pow(ctag, std::complex<long double>(1.0l)) == 4);
+  assert(pow(std::complex<long double>(1.0l), ctag) == 5);
+
+#if TEST_STD_VER >= 11
+  static_assert(pow(tag, std::complex<float>(1.0f)) == 0, "");
+  static_assert(pow(std::complex<float>(1.0f), tag) == 2, "");
+  static_assert(pow(tag, std::complex<double>(1.0)) == 0, "");
+  static_assert(pow(std::complex<double>(1.0), tag) == 2, "");
+  static_assert(pow(tag, std::complex<long double>(1.0l)) == 0, "");
+  static_assert(pow(std::complex<long double>(1.0l), tag) == 2, "");
+
+  static_assert(pow(1.0f, ctag) == 1, "");
+  static_assert(pow(ctag, 1.0f) == 3, "");
+  static_assert(pow(1.0, ctag) == 1, "");
+  static_assert(pow(ctag, 1.0) == 3, "");
+  static_assert(pow(1.0l, ctag) == 1, "");
+  static_assert(pow(ctag, 1.0l) == 3, "");
+
+  static_assert(pow(ctag, std::complex<float>(1.0f)) == 4, "");
+  static_assert(pow(std::complex<float>(1.0f), ctag) == 5, "");
+  static_assert(pow(ctag, std::complex<double>(1.0)) == 4, "");
+  static_assert(pow(std::complex<double>(1.0), ctag) == 5, "");
+  static_assert(pow(ctag, std::complex<long double>(1.0l)) == 4, "");
+  static_assert(pow(std::complex<long double>(1.0l), ctag) == 5, "");
+#endif
+}

>From 84e7c7b2345eb5d9ca4a4d3711d6a69d5e88f26e Mon Sep 17 00:00:00 2001
From: "A. Jiang" <de34 at live.cn>
Date: Sat, 19 Oct 2024 22:38:46 +0800
Subject: [PATCH 2/3] Address @ldionne's review comments

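- Add an explicit `return 0;` to `main`, which freestanding environments require.
- Drop the constexpr coverage (the `TEST_CONSTEXPR` annotations and the `static_assert` checks).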
---
 .../complex.number/cmplx.over.pow.pass.cpp    | 54 ++++++-------------
 1 file changed, 15 insertions(+), 39 deletions(-)

diff --git a/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp b/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp
index 64e679fed7435c..bd38fb7edba091 100644
--- a/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp
+++ b/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp
@@ -26,29 +26,26 @@ namespace usr {
 struct usr_tag {};
 
 template <class T, class U>
-TEST_CONSTEXPR
-    typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
-                                (std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
-                            int>::type
-    pow(const T&, const std::complex<U>&) {
+typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
+                            (std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
+                        int>::type
+pow(const T&, const std::complex<U>&) {
   return std::is_same<T, usr_tag>::value ? 0 : 1;
 }
 
 template <class T, class U>
-TEST_CONSTEXPR
-    typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
-                                (std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
-                            int>::type
-    pow(const std::complex<T>&, const U&) {
+typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
+                            (std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
+                        int>::type
+pow(const std::complex<T>&, const U&) {
   return std::is_same<U, usr_tag>::value ? 2 : 3;
 }
 
 template <class T, class U>
-TEST_CONSTEXPR
-    typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
-                                (std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
-                            int>::type
-    pow(const std::complex<T>&, const std::complex<U>&) {
+typename std::enable_if<(std::is_same<T, usr_tag>::value && std::is_floating_point<U>::value) ||
+                            (std::is_floating_point<T>::value && std::is_same<U, usr_tag>::value),
+                        int>::type
+pow(const std::complex<T>&, const std::complex<U>&) {
   return std::is_same<T, usr_tag>::value ? 4 : 5;
 }
 } // namespace usr
@@ -57,8 +54,8 @@ int main(int, char**) {
   using std::pow;
   using usr::pow;
 
-  TEST_CONSTEXPR usr::usr_tag tag;
-  TEST_CONSTEXPR_CXX14 const std::complex<usr::usr_tag> ctag;
+  usr::usr_tag tag;
+  const std::complex<usr::usr_tag> ctag;
 
   assert(pow(tag, std::complex<float>(1.0f)) == 0);
   assert(pow(std::complex<float>(1.0f), tag) == 2);
@@ -81,26 +78,5 @@ int main(int, char**) {
   assert(pow(ctag, std::complex<long double>(1.0l)) == 4);
   assert(pow(std::complex<long double>(1.0l), ctag) == 5);
 
-#if TEST_STD_VER >= 11
-  static_assert(pow(tag, std::complex<float>(1.0f)) == 0, "");
-  static_assert(pow(std::complex<float>(1.0f), tag) == 2, "");
-  static_assert(pow(tag, std::complex<double>(1.0)) == 0, "");
-  static_assert(pow(std::complex<double>(1.0), tag) == 2, "");
-  static_assert(pow(tag, std::complex<long double>(1.0l)) == 0, "");
-  static_assert(pow(std::complex<long double>(1.0l), tag) == 2, "");
-
-  static_assert(pow(1.0f, ctag) == 1, "");
-  static_assert(pow(ctag, 1.0f) == 3, "");
-  static_assert(pow(1.0, ctag) == 1, "");
-  static_assert(pow(ctag, 1.0) == 3, "");
-  static_assert(pow(1.0l, ctag) == 1, "");
-  static_assert(pow(ctag, 1.0l) == 3, "");
-
-  static_assert(pow(ctag, std::complex<float>(1.0f)) == 4, "");
-  static_assert(pow(std::complex<float>(1.0f), ctag) == 5, "");
-  static_assert(pow(ctag, std::complex<double>(1.0)) == 4, "");
-  static_assert(pow(std::complex<double>(1.0), ctag) == 5, "");
-  static_assert(pow(ctag, std::complex<long double>(1.0l)) == 4, "");
-  static_assert(pow(std::complex<long double>(1.0l), ctag) == 5, "");
-#endif
+  return 0;
 }

>From f6b91cf18dbc41f86266415e51151e158b038c41 Mon Sep 17 00:00:00 2001
From: "A. Jiang" <de34 at live.cn>
Date: Sun, 27 Oct 2024 14:51:25 +0800
Subject: [PATCH 3/3] Adopt @ldionne's clarification for comments

Co-authored-by: Louis Dionne <ldionne.2 at gmail.com>
---
 .../libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp    | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp b/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp
index bd38fb7edba091..1c790c283e4387 100644
--- a/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp
+++ b/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp
@@ -14,7 +14,9 @@
 
 // Test that these additional overloads are free from catching std::complex<non-floating-point>,
 // which is expected by several 3rd party libraries, see https://github.com/llvm/llvm-project/issues/109858.
-
+//
+// Note that we reserve the right to break this in the future if we have a reason to, but for the time being,
+// make sure we don't break this property unintentionally.
 #include <cassert>
 #include <cmath>
 #include <complex>



More information about the libcxx-commits mailing list