[flang-commits] [flang] f8a9f43 - [flang][runtime] Enable real/complex kind 10 and 16 variants of dot_product.
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Wed Aug 31 15:19:57 PDT 2022
Author: Slava Zakharin
Date: 2022-08-31T15:17:17-07:00
New Revision: f8a9f43ef7affb7991e60cdd5ce93d2566f5b2e4
URL: https://github.com/llvm/llvm-project/commit/f8a9f43ef7affb7991e60cdd5ce93d2566f5b2e4
DIFF: https://github.com/llvm/llvm-project/commit/f8a9f43ef7affb7991e60cdd5ce93d2566f5b2e4.diff
LOG: [flang][runtime] Enable real/complex kind 10 and 16 variants of dot_product.
HasCppTypeFor<> always evaluated to false, so the kind 10 and 16
variants of dot_product were never instantiated, even when the host
supported 80- and 128-bit real and complex data types.
In addition, HAS_FLOAT128 did not enable the complex kind=16 variant
of dot_product; this is fixed now.
Note that the change to HasCppTypeFor<> may also affect other
functions such as matmul, i.e. their kind 10 and 16 variants
may now become available (depending on the build host).
Differential Revision: https://reviews.llvm.org/D133051
Added:
Modified:
flang/include/flang/Runtime/cpp-type.h
flang/runtime/dot-product.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Runtime/cpp-type.h b/flang/include/flang/Runtime/cpp-type.h
index aa4b6f360124e..00af2c115484e 100644
--- a/flang/include/flang/Runtime/cpp-type.h
+++ b/flang/include/flang/Runtime/cpp-type.h
@@ -23,14 +23,14 @@ namespace Fortran::runtime {
using common::TypeCategory;
-template <TypeCategory CAT, int KIND> struct CppTypeForHelper {};
+template <TypeCategory CAT, int KIND> struct CppTypeForHelper {
+ using type = void;
+};
template <TypeCategory CAT, int KIND>
using CppTypeFor = typename CppTypeForHelper<CAT, KIND>::type;
-template <TypeCategory CAT, int KIND, bool SFINAE = false>
-constexpr bool HasCppTypeFor{false};
template <TypeCategory CAT, int KIND>
-constexpr bool HasCppTypeFor<CAT, KIND, true>{
+constexpr bool HasCppTypeFor{
!std::is_void_v<typename CppTypeForHelper<CAT, KIND>::type>};
template <int KIND> struct CppTypeForHelper<TypeCategory::Integer, KIND> {
diff --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp
index 2f9debbfccaa0..857ed6759817a 100644
--- a/flang/runtime/dot-product.cpp
+++ b/flang/runtime/dot-product.cpp
@@ -147,24 +147,24 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
};
extern "C" {
-std::int8_t RTNAME(DotProductInteger1)(
+CppTypeFor<TypeCategory::Integer, 1> RTNAME(DotProductInteger1)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
}
-std::int16_t RTNAME(DotProductInteger2)(
+CppTypeFor<TypeCategory::Integer, 2> RTNAME(DotProductInteger2)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
}
-std::int32_t RTNAME(DotProductInteger4)(
+CppTypeFor<TypeCategory::Integer, 4> RTNAME(DotProductInteger4)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
}
-std::int64_t RTNAME(DotProductInteger8)(
+CppTypeFor<TypeCategory::Integer, 8> RTNAME(DotProductInteger8)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
}
#ifdef __SIZEOF_INT128__
-common::int128_t RTNAME(DotProductInteger16)(
+CppTypeFor<TypeCategory::Integer, 16> RTNAME(DotProductInteger16)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
}
@@ -172,16 +172,16 @@ common::int128_t RTNAME(DotProductInteger16)(
// TODO: REAL/COMPLEX(2 & 3)
// Intermediate results and operations are at least 64 bits
-float RTNAME(DotProductReal4)(
+CppTypeFor<TypeCategory::Real, 4> RTNAME(DotProductReal4)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
}
-double RTNAME(DotProductReal8)(
+CppTypeFor<TypeCategory::Real, 8> RTNAME(DotProductReal8)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
}
#if LDBL_MANT_DIG == 64
-long double RTNAME(DotProductReal10)(
+CppTypeFor<TypeCategory::Real, 10> RTNAME(DotProductReal10)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
}
@@ -193,24 +193,25 @@ CppTypeFor<TypeCategory::Real, 16> RTNAME(DotProductReal16)(
}
#endif
-void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
+void RTNAME(CppDotProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)};
- result = std::complex<float>{
- static_cast<float>(z.real()), static_cast<float>(z.imag())};
+ result = DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line);
}
-void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
+void RTNAME(CppDotProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
}
#if LDBL_MANT_DIG == 64
-void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
- const Descriptor &x, const Descriptor &y, const char *source, int line) {
+void RTNAME(CppDotProductComplex10)(
+ CppTypeFor<TypeCategory::Complex, 10> &result, const Descriptor &x,
+ const Descriptor &y, const char *source, int line) {
result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
}
-#elif LDBL_MANT_DIG == 113
-void RTNAME(CppDotProductComplex16)(std::complex<CppFloat128Type> &result,
- const Descriptor &x, const Descriptor &y, const char *source, int line) {
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+void RTNAME(CppDotProductComplex16)(
+ CppTypeFor<TypeCategory::Complex, 16> &result, const Descriptor &x,
+ const Descriptor &y, const char *source, int line) {
result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
}
#endif
More information about the flang-commits
mailing list