[clang] ca5b315 - [HIP] Math Headers to use type promotion

Aaron En Ye Shi via cfe-commits cfe-commits at lists.llvm.org
Tue Nov 3 10:41:35 PST 2020


Author: Aaron En Ye Shi
Date: 2020-11-03T18:40:26Z
New Revision: ca5b31502c828f8e7160a77f54a5a131dc298005

URL: https://github.com/llvm/llvm-project/commit/ca5b31502c828f8e7160a77f54a5a131dc298005
DIFF: https://github.com/llvm/llvm-project/commit/ca5b31502c828f8e7160a77f54a5a131dc298005.diff

LOG: [HIP] Math Headers to use type promotion

Similar to the libcxx implementation of the cmath function
overloads, use type-promotion templates to determine the
return types of multi-argument math functions.

Fixes: SWDEV-256825

Reviewed By: tra, yaxunl

Differential Revision: https://reviews.llvm.org/D90409

Added: 
    

Modified: 
    clang/lib/Headers/__clang_hip_cmath.h

Removed: 
    


################################################################################
diff  --git a/clang/lib/Headers/__clang_hip_cmath.h b/clang/lib/Headers/__clang_hip_cmath.h
index fea799ead32f..00519a9795bc 100644
--- a/clang/lib/Headers/__clang_hip_cmath.h
+++ b/clang/lib/Headers/__clang_hip_cmath.h
@@ -16,6 +16,8 @@
 
 #if defined(__cplusplus)
 #include <limits>
+#include <type_traits>
+#include <utility>
 #endif
 #include <limits.h>
 #include <stdint.h>
@@ -205,6 +207,72 @@ template <bool __B, class __T = void> struct __hip_enable_if {};
 
 template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; };
 
+// decltype is only available in C++11 and above.
+#if __cplusplus >= 201103L
+// __hip_promote
+namespace __hip { // Type-promotion helpers: map argument types to the floating-point type a math call is evaluated in.
+
+template <class _Tp> struct __numeric_type { // Classifies _Tp by overload resolution over the __test declarations below.
+  static void __test(...); // Variadic fallback; selecting it yields `void`, which marks _Tp as non-numeric.
+  static _Float16 __test(_Float16); // The three floating-point types listed keep their own type...
+  static float __test(float);
+  static double __test(char); // ...while every integer overload from here on reports double.
+  static double __test(int);
+  static double __test(unsigned);
+  static double __test(long);
+  static double __test(unsigned long);
+  static double __test(long long);
+  static double __test(unsigned long long);
+  static double __test(double); // NOTE(review): no long double overload is listed; such an argument presumably cannot pick a unique overload here -- confirm.
+
+  typedef decltype(__test(std::declval<_Tp>())) type; // Return type of whichever __test overload a _Tp argument selects (unevaluated).
+  static const bool value = !std::is_same<type, void>::value; // True unless only the variadic fallback matched.
+};
+
+template <> struct __numeric_type<void> { static const bool value = true; }; // void is the "unused slot" default below; it counts as numeric and has no `type`.
+
+template <class _A1, class _A2 = void, class _A3 = void,
+          bool = __numeric_type<_A1>::value &&__numeric_type<_A2>::value
+              &&__numeric_type<_A3>::value> // Defaulted flag: true only when every supplied argument type is numeric.
+class __promote_imp { // Primary template; the specializations below all require the flag to be true.
+public:
+  static const bool value = false; // No `type` member is declared in this case.
+};
+
+template <class _A1, class _A2, class _A3>
+class __promote_imp<_A1, _A2, _A3, true> { // Three numeric argument types.
+private:
+  typedef typename __promote_imp<_A1>::type __type1; // Promote each argument on its own first.
+  typedef typename __promote_imp<_A2>::type __type2;
+  typedef typename __promote_imp<_A3>::type __type3;
+
+public:
+  typedef decltype(__type1() + __type2() + __type3()) type; // Common type is whatever the built-in `+` gives for the three promoted types.
+  static const bool value = true;
+};
+
+template <class _A1, class _A2> class __promote_imp<_A1, _A2, void, true> { // Two numeric argument types.
+private:
+  typedef typename __promote_imp<_A1>::type __type1;
+  typedef typename __promote_imp<_A2>::type __type2;
+
+public:
+  typedef decltype(__type1() + __type2()) type; // Common type is the type of the sum of the two promoted types.
+  static const bool value = true;
+};
+
+template <class _A1> class __promote_imp<_A1, void, void, true> { // Single numeric argument type.
+public:
+  typedef typename __numeric_type<_A1>::type type; // Taken directly from the __test overload table.
+  static const bool value = true;
+};
+
+template <class _A1, class _A2 = void, class _A3 = void>
+class __promote : public __promote_imp<_A1, _A2, _A3> {}; // Public entry point; inherits `value` and, when present, `type`.
+
+} // namespace __hip
+#endif //__cplusplus >= 201103L
+
 // __HIP_OVERLOAD1 is used to resolve function calls with integer argument to
 // avoid compilation error due to ambibuity. e.g. floor(5) is resolved with
 // floor(double).
@@ -219,6 +287,18 @@ template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; };
 // __HIP_OVERLOAD2 is used to resolve function calls with mixed float/double
 // or integer argument to avoid compilation error due to ambibuity. e.g.
 // max(5.0f, 6.0) is resolved with max(double, double).
+#if __cplusplus >= 201103L
+#define __HIP_OVERLOAD2(__retty, __fn) /* C++11 variant: defines a two-argument template overload of __fn; __retty is not referenced in this expansion */ \
+  template <typename __T1, typename __T2>                                      \
+  __DEVICE__ typename __hip_enable_if<                                         \
+      std::numeric_limits<__T1>::is_specialized && /* enabled only when both */ \
+          std::numeric_limits<__T2>::is_specialized, /* types have numeric_limits */ \
+      typename __hip::__promote<__T1, __T2>::type>::type /* return type is the promoted type */ \
+  __fn(__T1 __x, __T2 __y) {                                                   \
+    typedef typename __hip::__promote<__T1, __T2>::type __result_type;         \
+    return __fn((__result_type)__x, (__result_type)__y); /* cast both arguments to the promoted type, then call __fn again */ \
+  }
+#else
 #define __HIP_OVERLOAD2(__retty, __fn)                                         \
   template <typename __T1, typename __T2>                                      \
   __DEVICE__                                                                   \
@@ -228,6 +308,7 @@ template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; };
       __fn(__T1 __x, __T2 __y) {                                               \
     return __fn((double)__x, (double)__y);                                     \
   }
+#endif
 
 __HIP_OVERLOAD1(double, abs)
 __HIP_OVERLOAD1(double, acos)
@@ -296,6 +377,18 @@ __HIP_OVERLOAD2(double, max)
 __HIP_OVERLOAD2(double, min)
 
 // Additional Overloads that don't quite match HIP_OVERLOAD.
+#if __cplusplus >= 201103L
+template <typename __T1, typename __T2, typename __T3> // C++11 variant of the three-argument fma overload for mixed argument types.
+__DEVICE__ typename __hip_enable_if<
+    std::numeric_limits<__T1>::is_specialized && // Participates only when numeric_limits is specialized for all three types.
+        std::numeric_limits<__T2>::is_specialized &&
+        std::numeric_limits<__T3>::is_specialized,
+    typename __hip::__promote<__T1, __T2, __T3>::type>::type // Returns the promoted type of the three argument types.
+fma(__T1 __x, __T2 __y, __T3 __z) {
+  typedef typename __hip::__promote<__T1, __T2, __T3>::type __result_type;
+  return ::fma((__result_type)__x, (__result_type)__y, (__result_type)__z); // Cast every argument to the promoted type before calling the global fma.
+}
+#else
 template <typename __T1, typename __T2, typename __T3>
 __DEVICE__
     typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
@@ -305,6 +398,7 @@ __DEVICE__
     fma(__T1 __x, __T2 __y, __T3 __z) {
   return ::fma((double)__x, (double)__y, (double)__z);
 }
+#endif
 
 template <typename __T>
 __DEVICE__
@@ -327,6 +421,17 @@ __DEVICE__
   return ::modf((double)__x, __exp);
 }
 
+#if __cplusplus >= 201103L
+template <typename __T1, typename __T2> // C++11 variant of the remquo overload for mixed argument types.
+__DEVICE__
+    typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized && // Participates only when numeric_limits is specialized for both types.
+                                 std::numeric_limits<__T2>::is_specialized,
+                             typename __hip::__promote<__T1, __T2>::type>::type // Returns the promoted type of the two argument types.
+    remquo(__T1 __x, __T2 __y, int *__quo) {
+  typedef typename __hip::__promote<__T1, __T2>::type __result_type;
+  return ::remquo((__result_type)__x, (__result_type)__y, __quo); // Cast __x and __y to the promoted type; __quo is passed through unchanged.
+}
+#else
 template <typename __T1, typename __T2>
 __DEVICE__
     typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized &&
@@ -335,6 +440,7 @@ __DEVICE__
     remquo(__T1 __x, __T2 __y, int *__quo) {
   return ::remquo((double)__x, (double)__y, __quo);
 }
+#endif
 
 template <typename __T>
 __DEVICE__


        


More information about the cfe-commits mailing list