[clang] 56069b5 - [OpenMP] Support `std::complex` math functions in target regions

Johannes Doerfert via cfe-commits cfe-commits at lists.llvm.org
Wed Sep 16 11:40:10 PDT 2020


Author: Johannes Doerfert
Date: 2020-09-16T13:37:10-05:00
New Revision: 56069b5c71ca78749aa983c1e9de6f1e4c049f4b

URL: https://github.com/llvm/llvm-project/commit/56069b5c71ca78749aa983c1e9de6f1e4c049f4b
DIFF: https://github.com/llvm/llvm-project/commit/56069b5c71ca78749aa983c1e9de6f1e4c049f4b.diff

LOG: [OpenMP] Support `std::complex` math functions in target regions

The last (big) missing piece to get "math" working in OpenMP target
regions (that I know of) was complex math functions, e.g.,
`std::sin(std::complex<double>)`. With this patch we overload the system
template functions for these operations with versions that have been
distilled from `libcxx/include/complex`. We use the same
  `omp begin/end declare variant`
mechanism we used for other math functions before, except that this
time we overload templates (via D85735).

Reviewed By: JonChesterfield

Differential Revision: https://reviews.llvm.org/D85777

Added: 
    clang/lib/Headers/openmp_wrappers/complex_cmath.h
    clang/test/Headers/Inputs/include/type_traits

Modified: 
    clang/lib/Headers/CMakeLists.txt
    clang/lib/Headers/openmp_wrappers/complex
    clang/test/Headers/Inputs/include/complex
    clang/test/Headers/nvptx_device_math_complex.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt
index 0692fe75a441..a9761f049067 100644
--- a/clang/lib/Headers/CMakeLists.txt
+++ b/clang/lib/Headers/CMakeLists.txt
@@ -154,6 +154,7 @@ set(openmp_wrapper_files
   openmp_wrappers/complex.h
   openmp_wrappers/complex
   openmp_wrappers/__clang_openmp_device_functions.h
+  openmp_wrappers/complex_cmath.h
   openmp_wrappers/new
 )
 

diff  --git a/clang/lib/Headers/openmp_wrappers/complex b/clang/lib/Headers/openmp_wrappers/complex
index 1ed0b14879ef..306ffe208053 100644
--- a/clang/lib/Headers/openmp_wrappers/complex
+++ b/clang/lib/Headers/openmp_wrappers/complex
@@ -23,3 +23,28 @@
 
 // Grab the host header too.
 #include_next <complex>
+
+
+#ifdef __cplusplus
+
+// If we are compiling against libc++, the macro _LIBCPP_STD_VER is defined by
+// the standard library headers included above. Since complex_cmath.h is a
+// simplified copy of libc++'s <complex>, we don't need it in that case. If we
+// compile against libstdc++, or any other standard library, we overload the
+// (hopefully template) functions in its <complex> header with the ones taken
+// from libc++, which decompose complex math functions, like `std::sin`, into
+// arithmetic and calls to non-complex math functions, all of which we can
+// handle.
+#ifndef _LIBCPP_STD_VER
+
+#pragma omp begin declare variant match(                                       \
+    device = {arch(nvptx, nvptx64)},                                           \
+    implementation = {extension(match_any, allow_templates)})
+
+#include <complex_cmath.h>
+
+#pragma omp end declare variant
+
+#endif // _LIBCPP_STD_VER
+
+#endif // __cplusplus

diff  --git a/clang/lib/Headers/openmp_wrappers/complex_cmath.h b/clang/lib/Headers/openmp_wrappers/complex_cmath.h
new file mode 100644
index 000000000000..e3d9aebbbc24
--- /dev/null
+++ b/clang/lib/Headers/openmp_wrappers/complex_cmath.h
@@ -0,0 +1,388 @@
+//===-------------------------- complex_cmath.h ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// std::complex header copied from the libcxx source and simplified for use in
+// OpenMP target offload regions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _OPENMP
+#error "This file is for OpenMP compilation only."
+#endif
+
+#ifndef __cplusplus
+#error "This file is for C++ compilation only."
+#endif
+// Own guard: _LIBCPP_COMPLEX is the include guard of libc++'s real <complex>.
+#ifndef __CLANG_OPENMP_COMPLEX_CMATH_H__
+#define __CLANG_OPENMP_COMPLEX_CMATH_H__
+
+#include <cmath>
+#include <type_traits>
+
+#define __DEVICE__ static constexpr __attribute__((nothrow))
+
+namespace std {
+
+// abs: magnitude |z|, computed as hypot(re, im).
+
+template <class _Tp> __DEVICE__ _Tp abs(const std::complex<_Tp> &__c) {
+  return hypot(__c.real(), __c.imag());
+}
+
+// arg: phase atan2(im, re); the real-argument overloads return atan2(0, x).
+
+template <class _Tp> __DEVICE__ _Tp arg(const std::complex<_Tp> &__c) {
+  return atan2(__c.imag(), __c.real());
+}
+
+template <class _Tp>
+typename enable_if<is_integral<_Tp>::value || is_same<_Tp, double>::value,
+                   double>::type
+arg(_Tp __re) {
+  return atan2(0., __re);
+}
+
+template <class _Tp>
+typename enable_if<is_same<_Tp, float>::value, float>::type arg(_Tp __re) {
+  return atan2f(0.F, __re);
+}
+
+// norm: squared magnitude re*re + im*im; +inf if either part is infinite.
+
+template <class _Tp> __DEVICE__ _Tp norm(const std::complex<_Tp> &__c) {
+  if (std::isinf(__c.real()))
+    return abs(__c.real());
+  if (std::isinf(__c.imag()))
+    return abs(__c.imag());
+  return __c.real() * __c.real() + __c.imag() * __c.imag();
+}
+
+// conj: complex conjugate (re, -im).
+
+template <class _Tp> std::complex<_Tp> conj(const std::complex<_Tp> &__c) {
+  return std::complex<_Tp>(__c.real(), -__c.imag());
+}
+
+// proj: maps any z with an infinite part to (+inf, copysign(0, im)).
+
+template <class _Tp> std::complex<_Tp> proj(const std::complex<_Tp> &__c) {
+  std::complex<_Tp> __r = __c;
+  if (std::isinf(__c.real()) || std::isinf(__c.imag()))
+    __r = std::complex<_Tp>(INFINITY, copysign(_Tp(0), __c.imag()));
+  return __r;
+}
+
+// polar: rho * (cos(theta) + i*sin(theta)); NaN or sign-bit-set rho -> NaN.
+
+template <class _Tp>
+complex<_Tp> polar(const _Tp &__rho, const _Tp &__theta = _Tp()) {
+  if (std::isnan(__rho) || signbit(__rho))
+    return std::complex<_Tp>(_Tp(NAN), _Tp(NAN));
+  if (std::isnan(__theta)) {
+    if (std::isinf(__rho))
+      return std::complex<_Tp>(__rho, __theta);
+    return std::complex<_Tp>(__theta, __theta);
+  }
+  if (std::isinf(__theta)) {
+    if (std::isinf(__rho))
+      return std::complex<_Tp>(__rho, _Tp(NAN));
+    return std::complex<_Tp>(_Tp(NAN), _Tp(NAN));
+  }
+  _Tp __x = __rho * cos(__theta);
+  if (std::isnan(__x))
+    __x = 0;
+  _Tp __y = __rho * sin(__theta);
+  if (std::isnan(__y))
+    __y = 0;
+  return std::complex<_Tp>(__x, __y);
+}
+
+// log: log|z| + i*arg(z).
+
+template <class _Tp> std::complex<_Tp> log(const std::complex<_Tp> &__x) {
+  return std::complex<_Tp>(log(abs(__x)), arg(__x));
+}
+
+// log10: log(z) / log(10).
+
+template <class _Tp> std::complex<_Tp> log10(const std::complex<_Tp> &__x) {
+  return log(__x) / log(_Tp(10));
+}
+
+// sqrt: polar(sqrt|z|, arg(z) / 2), with infinite-part special cases first.
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> sqrt(const std::complex<_Tp> &__x) {
+  if (std::isinf(__x.imag()))
+    return std::complex<_Tp>(_Tp(INFINITY), __x.imag());
+  if (std::isinf(__x.real())) {
+    if (__x.real() > _Tp(0))
+      return std::complex<_Tp>(__x.real(), std::isnan(__x.imag())
+                                               ? __x.imag()
+                                               : copysign(_Tp(0), __x.imag()));
+    return std::complex<_Tp>(std::isnan(__x.imag()) ? __x.imag() : _Tp(0),
+                             copysign(__x.real(), __x.imag()));
+  }
+  return polar(sqrt(abs(__x)), arg(__x) / _Tp(2));
+}
+
+// exp: e^re * (cos(im) + i*sin(im)), with inf/NaN special cases first.
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> exp(const std::complex<_Tp> &__x) {
+  _Tp __i = __x.imag();
+  if (std::isinf(__x.real())) {
+    if (__x.real() < _Tp(0)) {
+      if (!std::isfinite(__i))
+        __i = _Tp(1);
+    } else if (__i == 0 || !std::isfinite(__i)) {
+      if (std::isinf(__i))
+        __i = _Tp(NAN);
+      return std::complex<_Tp>(__x.real(), __i);
+    }
+  } else if (std::isnan(__x.real()) && __x.imag() == 0)
+    return __x;
+  _Tp __e = exp(__x.real());
+  return std::complex<_Tp>(__e * cos(__i), __e * sin(__i));
+}
+
+// pow: x^y computed as exp(y * log(x)).
+
+template <class _Tp>
+std::complex<_Tp> pow(const std::complex<_Tp> &__x,
+                      const std::complex<_Tp> &__y) {
+  return exp(__y * log(__x));
+}
+
+// __sqr, computes z * z as ((re - im) * (re + im), 2 * re * im)
+
+template <class _Tp> std::complex<_Tp> __sqr(const std::complex<_Tp> &__x) {
+  return std::complex<_Tp>((__x.real() - __x.imag()) *
+                               (__x.real() + __x.imag()),
+                           _Tp(2) * __x.real() * __x.imag());
+}
+
+// asinh: log(z + sqrt(z*z + 1)) with the signs of z restored; inf/NaN first.
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> asinh(const std::complex<_Tp> &__x) {
+  const _Tp __pi(atan2(+0., -0.));
+  if (std::isinf(__x.real())) {
+    if (std::isnan(__x.imag()))
+      return __x;
+    if (std::isinf(__x.imag()))
+      return std::complex<_Tp>(__x.real(),
+                               copysign(__pi * _Tp(0.25), __x.imag()));
+    return std::complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag()));
+  }
+  if (std::isnan(__x.real())) {
+    if (std::isinf(__x.imag()))
+      return std::complex<_Tp>(__x.imag(), __x.real());
+    if (__x.imag() == 0)
+      return __x;
+    return std::complex<_Tp>(__x.real(), __x.real());
+  }
+  if (std::isinf(__x.imag()))
+    return std::complex<_Tp>(copysign(__x.imag(), __x.real()),
+                             copysign(__pi / _Tp(2), __x.imag()));
+  std::complex<_Tp> __z = log(__x + sqrt(__sqr(__x) + _Tp(1)));
+  return std::complex<_Tp>(copysign(__z.real(), __x.real()),
+                           copysign(__z.imag(), __x.imag()));
+}
+
+// acosh: log(z + sqrt(z*z - 1)) with a non-negative real part; inf/NaN first.
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> acosh(const std::complex<_Tp> &__x) {
+  const _Tp __pi(atan2(+0., -0.));
+  if (std::isinf(__x.real())) {
+    if (std::isnan(__x.imag()))
+      return std::complex<_Tp>(abs(__x.real()), __x.imag());
+    if (std::isinf(__x.imag())) {
+      if (__x.real() > 0)
+        return std::complex<_Tp>(__x.real(),
+                                 copysign(__pi * _Tp(0.25), __x.imag()));
+      else
+        return std::complex<_Tp>(-__x.real(),
+                                 copysign(__pi * _Tp(0.75), __x.imag()));
+    }
+    if (__x.real() < 0)
+      return std::complex<_Tp>(-__x.real(), copysign(__pi, __x.imag()));
+    return std::complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag()));
+  }
+  if (std::isnan(__x.real())) {
+    if (std::isinf(__x.imag()))
+      return std::complex<_Tp>(abs(__x.imag()), __x.real());
+    return std::complex<_Tp>(__x.real(), __x.real());
+  }
+  if (std::isinf(__x.imag()))
+    return std::complex<_Tp>(abs(__x.imag()),
+                             copysign(__pi / _Tp(2), __x.imag()));
+  std::complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1)));
+  return std::complex<_Tp>(copysign(__z.real(), _Tp(0)),
+                           copysign(__z.imag(), __x.imag()));
+}
+
+// atanh: log((1 + z) / (1 - z)) / 2; z = +-1 maps to +-inf; inf/NaN first.
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> atanh(const std::complex<_Tp> &__x) {
+  const _Tp __pi(atan2(+0., -0.));
+  if (std::isinf(__x.imag())) {
+    return std::complex<_Tp>(copysign(_Tp(0), __x.real()),
+                             copysign(__pi / _Tp(2), __x.imag()));
+  }
+  if (std::isnan(__x.imag())) {
+    if (std::isinf(__x.real()) || __x.real() == 0)
+      return std::complex<_Tp>(copysign(_Tp(0), __x.real()), __x.imag());
+    return std::complex<_Tp>(__x.imag(), __x.imag());
+  }
+  if (std::isnan(__x.real())) {
+    return std::complex<_Tp>(__x.real(), __x.real());
+  }
+  if (std::isinf(__x.real())) {
+    return std::complex<_Tp>(copysign(_Tp(0), __x.real()),
+                             copysign(__pi / _Tp(2), __x.imag()));
+  }
+  if (abs(__x.real()) == _Tp(1) && __x.imag() == _Tp(0)) {
+    return std::complex<_Tp>(copysign(_Tp(INFINITY), __x.real()),
+                             copysign(_Tp(0), __x.imag()));
+  }
+  std::complex<_Tp> __z = log((_Tp(1) + __x) / (_Tp(1) - __x)) / _Tp(2);
+  return std::complex<_Tp>(copysign(__z.real(), __x.real()),
+                           copysign(__z.imag(), __x.imag()));
+}
+
+// sinh: (sinh(re) * cos(im), cosh(re) * sin(im)), inf/NaN handled first.
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> sinh(const std::complex<_Tp> &__x) {
+  if (std::isinf(__x.real()) && !std::isfinite(__x.imag()))
+    return std::complex<_Tp>(__x.real(), _Tp(NAN));
+  if (__x.real() == 0 && !std::isfinite(__x.imag()))
+    return std::complex<_Tp>(__x.real(), _Tp(NAN));
+  if (__x.imag() == 0 && !std::isfinite(__x.real()))
+    return __x;
+  return std::complex<_Tp>(sinh(__x.real()) * cos(__x.imag()),
+                           cosh(__x.real()) * sin(__x.imag()));
+}
+
+// cosh: (cosh(re) * cos(im), sinh(re) * sin(im)), inf/NaN handled first.
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> cosh(const std::complex<_Tp> &__x) {
+  if (std::isinf(__x.real()) && !std::isfinite(__x.imag()))
+    return std::complex<_Tp>(abs(__x.real()), _Tp(NAN));
+  if (__x.real() == 0 && !std::isfinite(__x.imag()))
+    return std::complex<_Tp>(_Tp(NAN), __x.real());
+  if (__x.real() == 0 && __x.imag() == 0)
+    return std::complex<_Tp>(_Tp(1), __x.imag());
+  if (__x.imag() == 0 && !std::isfinite(__x.real()))
+    return std::complex<_Tp>(abs(__x.real()), __x.imag());
+  return std::complex<_Tp>(cosh(__x.real()) * cos(__x.imag()),
+                           sinh(__x.real()) * sin(__x.imag()));
+}
+
+// tanh. tanh is odd, so an infinite real part yields copysign(1, re) + i*(+-0),
+// e.g. tanh(-inf + 0i) == -1 rather than +1 (see C99 Annex G, ctanh).
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> tanh(const std::complex<_Tp> &__x) {
+  if (std::isinf(__x.real())) {
+    if (!std::isfinite(__x.imag()))
+      return std::complex<_Tp>(copysign(_Tp(1), __x.real()), _Tp(0));
+    return std::complex<_Tp>(copysign(_Tp(1), __x.real()),
+                             copysign(_Tp(0), sin(_Tp(2) * __x.imag())));
+  }
+  if (std::isnan(__x.real()) && __x.imag() == 0)
+    return __x;
+  _Tp __2r(_Tp(2) * __x.real());
+  _Tp __2i(_Tp(2) * __x.imag());
+  _Tp __d(cosh(__2r) + cos(__2i));
+  _Tp __2rsh(sinh(__2r));
+  if (std::isinf(__2rsh) && std::isinf(__d))
+    return std::complex<_Tp>(__2rsh > _Tp(0) ? _Tp(1) : _Tp(-1),
+                             __2i > _Tp(0) ? _Tp(0) : _Tp(-0.));
+  return std::complex<_Tp>(__2rsh / __d, sin(__2i) / __d);
+}
+
+// asin: asin(z) = -i * asinh(i * z).
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> asin(const std::complex<_Tp> &__x) {
+  std::complex<_Tp> __z = asinh(complex<_Tp>(-__x.imag(), __x.real()));
+  return std::complex<_Tp>(__z.imag(), -__z.real());
+}
+
+// acos: derived from log(z + sqrt(z*z - 1)); inf/NaN/zero cases handled first.
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> acos(const std::complex<_Tp> &__x) {
+  const _Tp __pi(atan2(+0., -0.));
+  if (std::isinf(__x.real())) {
+    if (std::isnan(__x.imag()))
+      return std::complex<_Tp>(__x.imag(), __x.real());
+    if (std::isinf(__x.imag())) {
+      if (__x.real() < _Tp(0))
+        return std::complex<_Tp>(_Tp(0.75) * __pi, -__x.imag());
+      return std::complex<_Tp>(_Tp(0.25) * __pi, -__x.imag());
+    }
+    if (__x.real() < _Tp(0))
+      return std::complex<_Tp>(__pi,
+                               signbit(__x.imag()) ? -__x.real() : __x.real());
+    return std::complex<_Tp>(_Tp(0),
+                             signbit(__x.imag()) ? __x.real() : -__x.real());
+  }
+  if (std::isnan(__x.real())) {
+    if (std::isinf(__x.imag()))
+      return std::complex<_Tp>(__x.real(), -__x.imag());
+    return std::complex<_Tp>(__x.real(), __x.real());
+  }
+  if (std::isinf(__x.imag()))
+    return std::complex<_Tp>(__pi / _Tp(2), -__x.imag());
+  if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag())))
+    return std::complex<_Tp>(__pi / _Tp(2), -__x.imag());
+  std::complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1)));
+  if (signbit(__x.imag()))
+    return std::complex<_Tp>(abs(__z.imag()), abs(__z.real()));
+  return std::complex<_Tp>(abs(__z.imag()), -abs(__z.real()));
+}
+
+// atan: atan(z) = -i * atanh(i * z).
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> atan(const std::complex<_Tp> &__x) {
+  std::complex<_Tp> __z = atanh(complex<_Tp>(-__x.imag(), __x.real()));
+  return std::complex<_Tp>(__z.imag(), -__z.real());
+}
+
+// sin: sin(z) = -i * sinh(i * z).
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> sin(const std::complex<_Tp> &__x) {
+  std::complex<_Tp> __z = sinh(complex<_Tp>(-__x.imag(), __x.real()));
+  return std::complex<_Tp>(__z.imag(), -__z.real());
+}
+
+// cos: cos(z) = cosh(i * z).
+
+template <class _Tp> std::complex<_Tp> cos(const std::complex<_Tp> &__x) {
+  return cosh(complex<_Tp>(-__x.imag(), __x.real()));
+}
+
+// tan: tan(z) = -i * tanh(i * z).
+
+template <class _Tp>
+__DEVICE__ std::complex<_Tp> tan(const std::complex<_Tp> &__x) {
+  std::complex<_Tp> __z = tanh(complex<_Tp>(-__x.imag(), __x.real()));
+  return std::complex<_Tp>(__z.imag(), -__z.real());
+}
+
+} // namespace std
+#undef __DEVICE__ // Do not leak the helper macro to files including this one.
+#endif

diff  --git a/clang/test/Headers/Inputs/include/complex b/clang/test/Headers/Inputs/include/complex
index f3aefab7954b..bd43cd952d7c 100644
--- a/clang/test/Headers/Inputs/include/complex
+++ b/clang/test/Headers/Inputs/include/complex
@@ -3,6 +3,7 @@
 #include <cmath>
 
 #define INFINITY (__builtin_inff())
+#define NAN (__builtin_nanf(""))
 
 namespace std {
 
@@ -298,4 +299,114 @@ operator!=(const _Tp &__x, const complex<_Tp> &__y) {
   return !(__x == __y);
 }
 
+template <class _Tp> _Tp abs(const std::complex<_Tp> &__c);
+
+// arg
+
+template <class _Tp> _Tp arg(const std::complex<_Tp> &__c);
+
+// norm
+
+template <class _Tp> _Tp norm(const std::complex<_Tp> &__c);
+
+// conj
+
+template <class _Tp> std::complex<_Tp> conj(const std::complex<_Tp> &__c);
+
+// proj
+
+template <class _Tp> std::complex<_Tp> proj(const std::complex<_Tp> &__c);
+
+// polar (a defaulted __theta is _Tp(), i.e. a zero phase)
+
+template <class _Tp>
+complex<_Tp> polar(const _Tp &__rho, const _Tp &__theta = _Tp());
+
+// log
+
+template <class _Tp> std::complex<_Tp> log(const std::complex<_Tp> &__x);
+
+// log10
+
+template <class _Tp> std::complex<_Tp> log10(const std::complex<_Tp> &__x);
+
+// sqrt
+
+template <class _Tp>
+std::complex<_Tp> sqrt(const std::complex<_Tp> &__x);
+
+// exp
+
+template <class _Tp>
+std::complex<_Tp> exp(const std::complex<_Tp> &__x);
+
+// pow
+
+template <class _Tp>
+std::complex<_Tp> pow(const std::complex<_Tp> &__x,
+                      const std::complex<_Tp> &__y);
+
+// __sqr, computes x * x (a libc++ helper, not a standard function)
+
+template <class _Tp> std::complex<_Tp> __sqr(const std::complex<_Tp> &__x);
+
+// asinh
+
+template <class _Tp>
+std::complex<_Tp> asinh(const std::complex<_Tp> &__x);
+
+// acosh
+
+template <class _Tp>
+std::complex<_Tp> acosh(const std::complex<_Tp> &__x);
+
+// atanh
+
+template <class _Tp>
+std::complex<_Tp> atanh(const std::complex<_Tp> &__x);
+
+// sinh
+
+template <class _Tp>
+std::complex<_Tp> sinh(const std::complex<_Tp> &__x);
+
+// cosh
+
+template <class _Tp>
+std::complex<_Tp> cosh(const std::complex<_Tp> &__x);
+
+// tanh
+
+template <class _Tp>
+std::complex<_Tp> tanh(const std::complex<_Tp> &__x);
+
+// asin
+
+template <class _Tp>
+std::complex<_Tp> asin(const std::complex<_Tp> &__x);
+
+// acos
+
+template <class _Tp>
+std::complex<_Tp> acos(const std::complex<_Tp> &__x);
+
+// atan
+
+template <class _Tp>
+std::complex<_Tp> atan(const std::complex<_Tp> &__x);
+
+// sin
+
+template <class _Tp>
+std::complex<_Tp> sin(const std::complex<_Tp> &__x);
+
+// cos
+
+template <class _Tp> std::complex<_Tp> cos(const std::complex<_Tp> &__x);
+
+// tan
+
+template <class _Tp>
+std::complex<_Tp> tan(const std::complex<_Tp> &__x);
+
 } // namespace std

diff  --git a/clang/test/Headers/Inputs/include/type_traits b/clang/test/Headers/Inputs/include/type_traits
new file mode 100644
index 000000000000..9fd02d51eff1
--- /dev/null
+++ b/clang/test/Headers/Inputs/include/type_traits
@@ -0,0 +1,43 @@
+/// Copied from libcxx type_traits and simplified
+
+#pragma once
+
+namespace std {
+
+template <class _Tp, _Tp __v>
+struct integral_constant {
+  static const _Tp value = __v;
+  typedef _Tp value_type;
+  typedef integral_constant type;
+};
+
+typedef integral_constant<bool, true> true_type;
+typedef integral_constant<bool, false> false_type;
+
+// is_same, functional
+template <class _Tp, class _Up> struct is_same : public false_type {};
+template <class _Tp> struct is_same<_Tp, _Tp> : public true_type {};
+
+// is_integral, for some types only (no unsigned or signed char variants).
+template <class _Tp> struct is_integral
+    : public integral_constant<bool, false> {};
+template <> struct is_integral<bool>
+    : public integral_constant<bool, true> {};
+template <> struct is_integral<char>
+    : public integral_constant<bool, true> {};
+template <> struct is_integral<short>
+    : public integral_constant<bool, true> {};
+template <> struct is_integral<int>
+    : public integral_constant<bool, true> {};
+template <> struct is_integral<long>
+    : public integral_constant<bool, true> {};
+template <> struct is_integral<long long>
+    : public integral_constant<bool, true> {};
+
+// enable_if, functional (_Bp as in libc++; _C is a macro in some <ctype.h>)
+template <bool _Bp, typename _Tp> struct enable_if{};
+template <typename _Tp> struct enable_if<true, _Tp>{
+  using type = _Tp;
+};
+
+}

diff  --git a/clang/test/Headers/nvptx_device_math_complex.cpp b/clang/test/Headers/nvptx_device_math_complex.cpp
index e4b78deb05d7..688fd5d101ea 100644
--- a/clang/test/Headers/nvptx_device_math_complex.cpp
+++ b/clang/test/Headers/nvptx_device_math_complex.cpp
@@ -3,6 +3,7 @@
 // RUN: %clang_cc1 -verify -internal-isystem %S/../../lib/Headers/openmp_wrappers -include __clang_openmp_device_functions.h -internal-isystem %S/Inputs/include -fopenmp -x c++ -triple nvptx64-unknown-unknown -fopenmp-targets=nvptx64-nvidia-cuda -emit-llvm %s -fopenmp-is-device -fopenmp-host-ir-file-path %t-ppc-host.bc -aux-triple powerpc64le-unknown-unknown -o - | FileCheck %s
 // expected-no-diagnostics
 
+#include <cmath>
 #include <complex>
 
 // CHECK: define weak {{.*}} @__muldc3
@@ -33,6 +34,12 @@
 // CHECK-DAG: call float @__nv_fabsf(
 // CHECK-DAG: call float @__nv_logbf(
 
+// We check that there are no declarations of non-OpenMP functions. That is,
+// as long as we don't call an unknown function whose name doesn't start with
+// '__' (llvm.* intrinsics excepted) we are good. Regexes must be in {{...}}.
+
+// CHECK-NOT: declare {{.*}}@{{([^_l]|_[^_]|l[^l])}}
+
 void test_scmplx(std::complex<float> a) {
 #pragma omp target
   {
@@ -46,3 +53,35 @@ void test_dcmplx(std::complex<double> a) {
     (void)(a * (a / a));
   }
 }
+
+template <typename T> // Device-side calls to the templated complex overloads.
+std::complex<T> test_template_math_calls(std::complex<T> a) {
+  decltype(a) r = a;
+#pragma omp target
+  {
+    r = std::sin(r);
+    r = std::cos(r);
+    r = std::exp(r);
+    r = std::atan(r);
+    r = std::acos(r);
+  }
+  return r;
+}
+
+std::complex<float> test_scall(std::complex<float> a) {
+  decltype(a) r;
+#pragma omp target
+  {
+    r = std::sin(a);
+  }
+  return test_template_math_calls(r); // Instantiates the template for float.
+}
+
+std::complex<double> test_dcall(std::complex<double> a) {
+  decltype(a) r;
+#pragma omp target
+  {
+    r = std::exp(a);
+  }
+  return test_template_math_calls(r); // Instantiates the template for double.
+}


        


More information about the cfe-commits mailing list