[clang] 8baba68 - [HIP] Support overloaded math functions for hipRTC

Yaxun Liu via cfe-commits cfe-commits at lists.llvm.org
Thu Apr 22 16:07:18 PDT 2021


Author: Yaxun (Sam) Liu
Date: 2021-04-22T19:06:51-04:00
New Revision: 8baba6890de74d33beb75646ebcbf168e949d578

URL: https://github.com/llvm/llvm-project/commit/8baba6890de74d33beb75646ebcbf168e949d578
DIFF: https://github.com/llvm/llvm-project/commit/8baba6890de74d33beb75646ebcbf168e949d578.diff

LOG: [HIP] Support overloaded math functions for hipRTC

Remove the dependence on standard C++ headers
for overloaded math functions in the HIP headers,
since standard C++ headers are not available for hipRTC.

Reviewed by: Artem Belevich, Justin Lebar

Differential Revision: https://reviews.llvm.org/D100794

Added: 
    

Modified: 
    clang/lib/Headers/__clang_hip_cmath.h
    clang/lib/Headers/__clang_hip_runtime_wrapper.h
    clang/test/Headers/hip-header.hip

Removed: 
    


################################################################################
diff  --git a/clang/lib/Headers/__clang_hip_cmath.h b/clang/lib/Headers/__clang_hip_cmath.h
index 632d46e47f8b9..5d7b75ffdcc0f 100644
--- a/clang/lib/Headers/__clang_hip_cmath.h
+++ b/clang/lib/Headers/__clang_hip_cmath.h
@@ -22,7 +22,7 @@
 #endif
 #include <limits.h>
 #include <stdint.h>
-#endif // __HIPCC_RTC__
+#endif // !defined(__HIPCC_RTC__)
 
 #pragma push_macro("__DEVICE__")
 #define __DEVICE__ static __device__ inline __attribute__((always_inline))
@@ -36,6 +36,9 @@ __DEVICE__ long abs(long __n) { return ::labs(__n); }
 __DEVICE__ float fma(float __x, float __y, float __z) {
   return ::fmaf(__x, __y, __z);
 }
+#if !defined(__HIPCC_RTC__)
+// The value returned by fpclassify is platform dependent, therefore it is not
+// supported by hipRTC.
 __DEVICE__ int fpclassify(float __x) {
   return __builtin_fpclassify(FP_NAN, FP_INFINITE, FP_NORMAL, FP_SUBNORMAL,
                               FP_ZERO, __x);
@@ -44,6 +47,8 @@ __DEVICE__ int fpclassify(double __x) {
   return __builtin_fpclassify(FP_NAN, FP_INFINITE, FP_NORMAL, FP_SUBNORMAL,
                               FP_ZERO, __x);
 }
+#endif // !defined(__HIPCC_RTC__)
+
 __DEVICE__ float frexp(float __arg, int *__exp) {
   return ::frexpf(__arg, __exp);
 }
@@ -209,11 +214,122 @@ template <bool __B, class __T = void> struct __hip_enable_if {};
 
 template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; };
 
+namespace __hip {
+template <class _Tp> struct is_integral {
+  enum { value = 0 };
+};
+template <> struct is_integral<bool> {
+  enum { value = 1 };
+};
+template <> struct is_integral<char> {
+  enum { value = 1 };
+};
+template <> struct is_integral<signed char> {
+  enum { value = 1 };
+};
+template <> struct is_integral<unsigned char> {
+  enum { value = 1 };
+};
+template <> struct is_integral<wchar_t> {
+  enum { value = 1 };
+};
+template <> struct is_integral<short> {
+  enum { value = 1 };
+};
+template <> struct is_integral<unsigned short> {
+  enum { value = 1 };
+};
+template <> struct is_integral<int> {
+  enum { value = 1 };
+};
+template <> struct is_integral<unsigned int> {
+  enum { value = 1 };
+};
+template <> struct is_integral<long> {
+  enum { value = 1 };
+};
+template <> struct is_integral<unsigned long> {
+  enum { value = 1 };
+};
+template <> struct is_integral<long long> {
+  enum { value = 1 };
+};
+template <> struct is_integral<unsigned long long> {
+  enum { value = 1 };
+};
+
+// TODO: specialize is_arithmetic<_Float16>.
+template <class _Tp> struct is_arithmetic {
+  enum { value = 0 };
+};
+template <> struct is_arithmetic<bool> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<char> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<signed char> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<unsigned char> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<wchar_t> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<short> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<unsigned short> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<int> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<unsigned int> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<long> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<unsigned long> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<long long> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<unsigned long long> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<float> {
+  enum { value = 1 };
+};
+template <> struct is_arithmetic<double> {
+  enum { value = 1 };
+};
+// Accept long double, as std::is_arithmetic does; __numeric_type maps it to
+// double because long double is not supported on the device.
+template <> struct is_arithmetic<long double> {
+  enum { value = 1 };
+};
+
+struct true_type {
+  static const __constant__ bool value = true;
+};
+struct false_type {
+  static const __constant__ bool value = false;
+};
+
+template <typename __T, typename __U> struct is_same : public false_type {};
+template <typename __T> struct is_same<__T, __T> : public true_type {};
+
+template <typename __T> struct add_rvalue_reference { typedef __T &&type; };
+
+template <typename __T> typename add_rvalue_reference<__T>::type declval();
+
 // decltype is only available in C++11 and above.
 #if __cplusplus >= 201103L
 // __hip_promote
-namespace __hip {
-
 template <class _Tp> struct __numeric_type {
   static void __test(...);
   static _Float16 __test(_Float16);
@@ -229,8 +340,8 @@ template <class _Tp> struct __numeric_type {
   // No support for long double, use double instead.
   static double __test(long double);
 
-  typedef decltype(__test(std::declval<_Tp>())) type;
-  static const bool value = !std::is_same<type, void>::value;
+  typedef decltype(__test(declval<_Tp>())) type;
+  static const bool value = !is_same<type, void>::value;
 };
 
 template <> struct __numeric_type<void> { static const bool value = true; };
@@ -273,18 +384,17 @@ template <class _A1> class __promote_imp<_A1, void, void, true> {
 
 template <class _A1, class _A2 = void, class _A3 = void>
 class __promote : public __promote_imp<_A1, _A2, _A3> {};
-
-} // namespace __hip
 #endif //__cplusplus >= 201103L
+} // namespace __hip
 
 // __HIP_OVERLOAD1 is used to resolve function calls with integer argument to
 // avoid compilation error due to ambibuity. e.g. floor(5) is resolved with
 // floor(double).
 #define __HIP_OVERLOAD1(__retty, __fn)                                         \
   template <typename __T>                                                      \
-  __DEVICE__ typename __hip_enable_if<std::numeric_limits<__T>::is_integer,    \
-                                      __retty>::type                           \
-  __fn(__T __x) {                                                              \
+  __DEVICE__                                                                   \
+      typename __hip_enable_if<__hip::is_integral<__T>::value, __retty>::type  \
+      __fn(__T __x) {                                                          \
     return ::__fn((double)__x);                                                \
   }
 
@@ -295,8 +405,7 @@ class __promote : public __promote_imp<_A1, _A2, _A3> {};
 #define __HIP_OVERLOAD2(__retty, __fn)                                         \
   template <typename __T1, typename __T2>                                      \
   __DEVICE__ typename __hip_enable_if<                                         \
-      std::numeric_limits<__T1>::is_specialized &&                             \
-          std::numeric_limits<__T2>::is_specialized,                           \
+      __hip::is_arithmetic<__T1>::value && __hip::is_arithmetic<__T2>::value,  \
       typename __hip::__promote<__T1, __T2>::type>::type                       \
   __fn(__T1 __x, __T2 __y) {                                                   \
     typedef typename __hip::__promote<__T1, __T2>::type __result_type;         \
@@ -305,11 +414,10 @@ class __promote : public __promote_imp<_A1, _A2, _A3> {};
 #else
 #define __HIP_OVERLOAD2(__retty, __fn)                                         \
   template <typename __T1, typename __T2>                                      \
-  __DEVICE__                                                                   \
-      typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&    \
-                                   std::numeric_limits<__T2>::is_specialized,  \
-                               __retty>::type                                  \
-      __fn(__T1 __x, __T2 __y) {                                               \
+  __DEVICE__ typename __hip_enable_if<__hip::is_arithmetic<__T1>::value &&     \
+                                          __hip::is_arithmetic<__T2>::value,   \
+                                      __retty>::type                           \
+  __fn(__T1 __x, __T2 __y) {                                                   \
     return __fn((double)__x, (double)__y);                                     \
   }
 #endif
@@ -337,7 +445,9 @@ __HIP_OVERLOAD1(double, floor)
 __HIP_OVERLOAD2(double, fmax)
 __HIP_OVERLOAD2(double, fmin)
 __HIP_OVERLOAD2(double, fmod)
+#if !defined(__HIPCC_RTC__)
 __HIP_OVERLOAD1(int, fpclassify)
+#endif // !defined(__HIPCC_RTC__)
 __HIP_OVERLOAD2(double, hypot)
 __HIP_OVERLOAD1(int, ilogb)
 __HIP_OVERLOAD1(bool, isfinite)
@@ -383,9 +493,8 @@ __HIP_OVERLOAD2(double, min)
 #if __cplusplus >= 201103L
 template <typename __T1, typename __T2, typename __T3>
 __DEVICE__ typename __hip_enable_if<
-    std::numeric_limits<__T1>::is_specialized &&
-        std::numeric_limits<__T2>::is_specialized &&
-        std::numeric_limits<__T3>::is_specialized,
+    __hip::is_arithmetic<__T1>::value && __hip::is_arithmetic<__T2>::value &&
+        __hip::is_arithmetic<__T3>::value,
     typename __hip::__promote<__T1, __T2, __T3>::type>::type
 fma(__T1 __x, __T2 __y, __T3 __z) {
   typedef typename __hip::__promote<__T1, __T2, __T3>::type __result_type;
@@ -393,33 +502,32 @@ fma(__T1 __x, __T2 __y, __T3 __z) {
 }
 #else
 template <typename __T1, typename __T2, typename __T3>
-__DEVICE__
-    typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
-                                 std::numeric_limits<__T2>::is_specialized &&
-                                 std::numeric_limits<__T3>::is_specialized,
-                             double>::type
-    fma(__T1 __x, __T2 __y, __T3 __z) {
+__DEVICE__ typename __hip_enable_if<__hip::is_arithmetic<__T1>::value &&
+                                        __hip::is_arithmetic<__T2>::value &&
+                                        __hip::is_arithmetic<__T3>::value,
+                                    double>::type
+fma(__T1 __x, __T2 __y, __T3 __z) {
   return ::fma((double)__x, (double)__y, (double)__z);
 }
 #endif
 
 template <typename __T>
 __DEVICE__
-    typename __hip_enable_if<std::numeric_limits<__T>::is_integer, double>::type
+    typename __hip_enable_if<__hip::is_integral<__T>::value, double>::type
     frexp(__T __x, int *__exp) {
   return ::frexp((double)__x, __exp);
 }
 
 template <typename __T>
 __DEVICE__
-    typename __hip_enable_if<std::numeric_limits<__T>::is_integer, double>::type
+    typename __hip_enable_if<__hip::is_integral<__T>::value, double>::type
     ldexp(__T __x, int __exp) {
   return ::ldexp((double)__x, __exp);
 }
 
 template <typename __T>
 __DEVICE__
-    typename __hip_enable_if<std::numeric_limits<__T>::is_integer, double>::type
+    typename __hip_enable_if<__hip::is_integral<__T>::value, double>::type
     modf(__T __x, double *__exp) {
   return ::modf((double)__x, __exp);
 }
@@ -427,8 +535,8 @@ __DEVICE__
 #if __cplusplus >= 201103L
 template <typename __T1, typename __T2>
 __DEVICE__
-    typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
-                                 std::numeric_limits<__T2>::is_specialized,
+    typename __hip_enable_if<__hip::is_arithmetic<__T1>::value &&
+                                 __hip::is_arithmetic<__T2>::value,
                              typename __hip::__promote<__T1, __T2>::type>::type
     remquo(__T1 __x, __T2 __y, int *__quo) {
   typedef typename __hip::__promote<__T1, __T2>::type __result_type;
@@ -436,25 +544,24 @@ __DEVICE__
 }
 #else
 template <typename __T1, typename __T2>
-__DEVICE__
-    typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
-                                 std::numeric_limits<__T2>::is_specialized,
-                             double>::type
-    remquo(__T1 __x, __T2 __y, int *__quo) {
+__DEVICE__ typename __hip_enable_if<__hip::is_arithmetic<__T1>::value &&
+                                        __hip::is_arithmetic<__T2>::value,
+                                    double>::type
+remquo(__T1 __x, __T2 __y, int *__quo) {
   return ::remquo((double)__x, (double)__y, __quo);
 }
 #endif
 
 template <typename __T>
 __DEVICE__
-    typename __hip_enable_if<std::numeric_limits<__T>::is_integer, double>::type
+    typename __hip_enable_if<__hip::is_integral<__T>::value, double>::type
     scalbln(__T __x, long int __exp) {
   return ::scalbln((double)__x, __exp);
 }
 
 template <typename __T>
 __DEVICE__
-    typename __hip_enable_if<std::numeric_limits<__T>::is_integer, double>::type
+    typename __hip_enable_if<__hip::is_integral<__T>::value, double>::type
     scalbn(__T __x, int __exp) {
   return ::scalbn((double)__x, __exp);
 }
@@ -469,6 +576,7 @@ __DEVICE__
 #endif // defined(__cplusplus)
 
 // Define these overloads inside the namespace our standard library uses.
+#if !defined(__HIPCC_RTC__)
 #ifdef _LIBCPP_BEGIN_NAMESPACE_STD
 _LIBCPP_BEGIN_NAMESPACE_STD
 #else
@@ -476,7 +584,7 @@ namespace std {
 #ifdef _GLIBCXX_BEGIN_NAMESPACE_VERSION
 _GLIBCXX_BEGIN_NAMESPACE_VERSION
 #endif
-#endif
+#endif // _LIBCPP_BEGIN_NAMESPACE_STD
 
 // Pull the new overloads we defined above into namespace std.
 // using ::abs; - This may be considered for C++.
@@ -624,8 +732,10 @@ _GLIBCXX_END_NAMESPACE_VERSION
 #endif
 } // namespace std
 #endif
+#endif // !defined(__HIPCC_RTC__)
 
 // Define device-side math functions from <ymath.h> on MSVC.
+#if !defined(__HIPCC_RTC__)
 #if defined(_MSC_VER)
 
 // Before VS2019, `<ymath.h>` is also included in `<limits>` and other headers.
@@ -659,6 +769,7 @@ __DEVICE__ __attribute__((overloadable)) float _FSinh(float x, float y) {
 }
 #endif // defined(__cplusplus)
 #endif // defined(_MSC_VER)
+#endif // !defined(__HIPCC_RTC__)
 
 #pragma pop_macro("__DEVICE__")
 

diff  --git a/clang/lib/Headers/__clang_hip_runtime_wrapper.h b/clang/lib/Headers/__clang_hip_runtime_wrapper.h
index 8ee5566b33cf8..58f148f9a2680 100644
--- a/clang/lib/Headers/__clang_hip_runtime_wrapper.h
+++ b/clang/lib/Headers/__clang_hip_runtime_wrapper.h
@@ -72,12 +72,13 @@ static inline __device__ void *free(void *__ptr) {
 #include <__clang_hip_libdevice_declares.h>
 #include <__clang_hip_math.h>
 
-#if !defined(__HIPCC_RTC__)
 #if !_OPENMP || __HIP_ENABLE_CUDA_WRAPPER_FOR_OPENMP__
+#if defined(__HIPCC_RTC__)
+#include <__clang_hip_cmath.h>
+#else
 #include <__clang_cuda_math_forward_declares.h>
 #include <__clang_hip_cmath.h>
 #include <__clang_cuda_complex_builtins.h>
-
 #include <algorithm>
 #include <complex>
 #include <new>

diff  --git a/clang/test/Headers/hip-header.hip b/clang/test/Headers/hip-header.hip
index c0ca394c613ce..323138613055f 100644
--- a/clang/test/Headers/hip-header.hip
+++ b/clang/test/Headers/hip-header.hip
@@ -10,18 +10,54 @@
 // RUN:   -internal-isystem %S/Inputs/include \
 // RUN:   -triple amdgcn-amd-amdhsa -aux-triple x86_64-unknown-unknown \
 // RUN:   -target-cpu gfx906 -emit-llvm %s -fcuda-is-device -o - \
-// RUN:   -D__HIPCC_RTC__ -std=c++14 | FileCheck %s
-
+// RUN:   -D__HIPCC_RTC__ -std=c++14 | FileCheck -check-prefixes=CHECK,CXX14 %s
 
 // expected-no-diagnostics
 
+struct Number {
+  __device__ Number(float _x) : x(_x) {}
+  float x;
+};
+
+#if __cplusplus >= 201103L
+// Check __hip::__numeric_type can be used with a class without default ctor.
+__device__ void test_numeric_type() {
+  int x = __hip::__numeric_type<Number>::value;
+}
+
+// ToDo: Fix __clang_hip_cmath.h to specialize __hip::is_arithmetic<_Float16>
+// to resolve fma(_Float16, _Float16, int) to fma(double, double, double)
+// instead of fma(_Float16, _Float16, _Float16).
+
+// CXX14-LABEL: define{{.*}}@_Z8test_fma
+// CXX14: call {{.*}}@__ocml_fma_f16
+__device__ double test_fma(_Float16 h, int i) {
+  return fma(h, h, i);
+}
+
+#endif
+
 // CHECK-LABEL: amdgpu_kernel void @_Z4kernPff
 __global__ void kern(float *x, float y) {
   *x = sin(y);
 }
 
 // CHECK-LABEL: define{{.*}} i64 @_Z11test_size_tv
-// CHEC: ret i64 8
+// CHECK: ret i64 8
 __device__ size_t test_size_t() {
   return sizeof(size_t);
 }
+
+// Check there is no ambiguity when calling overloaded math functions.
+
+// CHECK-LABEL: define{{.*}}@_Z10test_floorv
+// CHECK: call {{.*}}double @__ocml_floor_f64(double
+__device__ float test_floor() {
+  return floor(5);
+}
+
+// CHECK-LABEL: define{{.*}}@_Z8test_maxv
+// CHECK: call {{.*}}double @__ocml_fmax_f64(double {{.*}}, double
+__device__ float test_max() {
+  return max(5, 6.0);
+}


        


More information about the cfe-commits mailing list