[flang-commits] [flang] f8a9f43 - [flang][runtime] Enable real/complex kind 10 and 16 variants of dot_product.

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Wed Aug 31 15:19:57 PDT 2022


Author: Slava Zakharin
Date: 2022-08-31T15:17:17-07:00
New Revision: f8a9f43ef7affb7991e60cdd5ce93d2566f5b2e4

URL: https://github.com/llvm/llvm-project/commit/f8a9f43ef7affb7991e60cdd5ce93d2566f5b2e4
DIFF: https://github.com/llvm/llvm-project/commit/f8a9f43ef7affb7991e60cdd5ce93d2566f5b2e4.diff

LOG: [flang][runtime] Enable real/complex kind 10 and 16 variants of dot_product.

HasCppTypeFor<> always evaluated to false, so the kind 10 and 16
variants of dot_product were not instantiated even when the host
supported 80- and 128-bit real and complex data types. The primary
CppTypeForHelper template now defines its type as void, so
HasCppTypeFor<> can simply check whether that type is void.
In addition, HAS_FLOAT128 did not enable the complex kind=16 variant
of dot_product. Both issues are fixed now.

Note that the HasCppTypeFor<> change may also affect other
functions, such as matmul: their kind 10 and 16 variants may now
become available, depending on the build host.

Differential Revision: https://reviews.llvm.org/D133051

Added: 
    

Modified: 
    flang/include/flang/Runtime/cpp-type.h
    flang/runtime/dot-product.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Runtime/cpp-type.h b/flang/include/flang/Runtime/cpp-type.h
index aa4b6f360124e..00af2c115484e 100644
--- a/flang/include/flang/Runtime/cpp-type.h
+++ b/flang/include/flang/Runtime/cpp-type.h
@@ -23,14 +23,14 @@ namespace Fortran::runtime {
 
 using common::TypeCategory;
 
-template <TypeCategory CAT, int KIND> struct CppTypeForHelper {};
+template <TypeCategory CAT, int KIND> struct CppTypeForHelper {
+  using type = void;
+};
 template <TypeCategory CAT, int KIND>
 using CppTypeFor = typename CppTypeForHelper<CAT, KIND>::type;
 
-template <TypeCategory CAT, int KIND, bool SFINAE = false>
-constexpr bool HasCppTypeFor{false};
 template <TypeCategory CAT, int KIND>
-constexpr bool HasCppTypeFor<CAT, KIND, true>{
+constexpr bool HasCppTypeFor{
     !std::is_void_v<typename CppTypeForHelper<CAT, KIND>::type>};
 
 template <int KIND> struct CppTypeForHelper<TypeCategory::Integer, KIND> {

diff  --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp
index 2f9debbfccaa0..857ed6759817a 100644
--- a/flang/runtime/dot-product.cpp
+++ b/flang/runtime/dot-product.cpp
@@ -147,24 +147,24 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
 };
 
 extern "C" {
-std::int8_t RTNAME(DotProductInteger1)(
+CppTypeFor<TypeCategory::Integer, 1> RTNAME(DotProductInteger1)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
   return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
 }
-std::int16_t RTNAME(DotProductInteger2)(
+CppTypeFor<TypeCategory::Integer, 2> RTNAME(DotProductInteger2)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
   return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
 }
-std::int32_t RTNAME(DotProductInteger4)(
+CppTypeFor<TypeCategory::Integer, 4> RTNAME(DotProductInteger4)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
   return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
 }
-std::int64_t RTNAME(DotProductInteger8)(
+CppTypeFor<TypeCategory::Integer, 8> RTNAME(DotProductInteger8)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
   return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
 }
 #ifdef __SIZEOF_INT128__
-common::int128_t RTNAME(DotProductInteger16)(
+CppTypeFor<TypeCategory::Integer, 16> RTNAME(DotProductInteger16)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
   return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
 }
@@ -172,16 +172,16 @@ common::int128_t RTNAME(DotProductInteger16)(
 
 // TODO: REAL/COMPLEX(2 & 3)
 // Intermediate results and operations are at least 64 bits
-float RTNAME(DotProductReal4)(
+CppTypeFor<TypeCategory::Real, 4> RTNAME(DotProductReal4)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
   return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
 }
-double RTNAME(DotProductReal8)(
+CppTypeFor<TypeCategory::Real, 8> RTNAME(DotProductReal8)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
   return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
 }
 #if LDBL_MANT_DIG == 64
-long double RTNAME(DotProductReal10)(
+CppTypeFor<TypeCategory::Real, 10> RTNAME(DotProductReal10)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
   return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
 }
@@ -193,24 +193,25 @@ CppTypeFor<TypeCategory::Real, 16> RTNAME(DotProductReal16)(
 }
 #endif
 
-void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
+void RTNAME(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)};
-  result = std::complex<float>{
-      static_cast<float>(z.real()), static_cast<float>(z.imag())};
+  result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
 }
-void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
+void RTNAME(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
   result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
 }
 #if LDBL_MANT_DIG == 64
-void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
-    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+void RTNAME(CppDotProductComplex10)(
+    CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
+    const Descriptor &y, const char *source, int line) {
   result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
 }
-#elif LDBL_MANT_DIG == 113
-void RTNAME(CppDotProductComplex16)(std::complex<CppFloat128Type> &result,
-    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+void RTNAME(CppDotProductComplex16)(
+    CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
+    const Descriptor &y, const char *source, int line) {
   result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
 }
 #endif


        


More information about the flang-commits mailing list