[flang-commits] [flang] [flang][runtime] Support SUM/PRODUCT/DOT_PRODUCT reductions for REAL(16). (PR #83169)
via flang-commits
flang-commits at lists.llvm.org
Tue Feb 27 11:03:02 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
The reduction implementations rely on trivial operations that
are supported by the build compiler runtime, so they can be enabled
whenever the build compiler provides 128-bit float support.
std::conj used by DOT_PRODUCT is a template implementation
in most environments, so it should not introduce a dependency
on any 128-bit float support library. I am not going to
test it in all the build environments before merging.
If it fails for someone, I will deal with it.
---
Full diff: https://github.com/llvm/llvm-project/pull/83169.diff
8 Files Affected:
- (modified) flang/include/flang/Common/float128.h (+16)
- (modified) flang/include/flang/Runtime/reduction.h (+9-3)
- (modified) flang/runtime/Float128Math/cabs.cpp (+2-4)
- (modified) flang/runtime/Float128Math/math-entries.h (-9)
- (modified) flang/runtime/complex-reduction.c (+38-9)
- (modified) flang/runtime/complex-reduction.h (+10-3)
- (modified) flang/runtime/product.cpp (+4-2)
- (modified) flang/runtime/sum.cpp (+2-1)
``````````diff
diff --git a/flang/include/flang/Common/float128.h b/flang/include/flang/Common/float128.h
index 3443aa06437b04..6aa98df5529df2 100644
--- a/flang/include/flang/Common/float128.h
+++ b/flang/include/flang/Common/float128.h
@@ -49,4 +49,20 @@
#endif /* (defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__)) && \
!defined(_LIBCPP_VERSION) && !defined(__CUDA_ARCH__) */
+/* Define pure C CFloat128Type and CFloat128ComplexType. */
+#if LDBL_MANT_DIG == 113
+typedef long double CFloat128Type;
+typedef long double _Complex CFloat128ComplexType;
+#elif HAS_FLOAT128
+typedef __float128 CFloat128Type;
+/*
+ * Use mode() attribute supported by GCC and Clang.
+ * Adjust it for other compilers as needed.
+ */
+#if !defined(_ARCH_PPC) || defined(__LONG_DOUBLE_IEEE128__)
+typedef _Complex float __attribute__((mode(TC))) CFloat128ComplexType;
+#else
+typedef _Complex float __attribute__((mode(KC))) CFloat128ComplexType;
+#endif
+#endif
#endif /* FORTRAN_COMMON_FLOAT128_H_ */
diff --git a/flang/include/flang/Runtime/reduction.h b/flang/include/flang/Runtime/reduction.h
index b91fec0cd26b51..6d62f4016937e0 100644
--- a/flang/include/flang/Runtime/reduction.h
+++ b/flang/include/flang/Runtime/reduction.h
@@ -92,9 +92,11 @@ void RTDECL(CppSumComplex8)(std::complex<double> &, const Descriptor &,
void RTDECL(CppSumComplex10)(std::complex<long double> &, const Descriptor &,
const char *source, int line, int dim = 0,
const Descriptor *mask = nullptr);
-void RTDECL(CppSumComplex16)(std::complex<long double> &, const Descriptor &,
- const char *source, int line, int dim = 0,
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+void RTDECL(CppSumComplex16)(std::complex<CppFloat128Type> &,
+ const Descriptor &, const char *source, int line, int dim = 0,
const Descriptor *mask = nullptr);
+#endif
void RTDECL(SumDim)(Descriptor &result, const Descriptor &array, int dim,
const char *source, int line, const Descriptor *mask = nullptr);
@@ -145,12 +147,16 @@ void RTDECL(CppProductComplex4)(std::complex<float> &, const Descriptor &,
void RTDECL(CppProductComplex8)(std::complex<double> &, const Descriptor &,
const char *source, int line, int dim = 0,
const Descriptor *mask = nullptr);
+#if LDBL_MANT_DIG == 64
void RTDECL(CppProductComplex10)(std::complex<long double> &,
const Descriptor &, const char *source, int line, int dim = 0,
const Descriptor *mask = nullptr);
-void RTDECL(CppProductComplex16)(std::complex<long double> &,
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+void RTDECL(CppProductComplex16)(std::complex<CppFloat128Type> &,
const Descriptor &, const char *source, int line, int dim = 0,
const Descriptor *mask = nullptr);
+#endif
void RTDECL(ProductDim)(Descriptor &result, const Descriptor &array, int dim,
const char *source, int line, const Descriptor *mask = nullptr);
diff --git a/flang/runtime/Float128Math/cabs.cpp b/flang/runtime/Float128Math/cabs.cpp
index 63f2bdf8e177ae..2867c8a4578a80 100644
--- a/flang/runtime/Float128Math/cabs.cpp
+++ b/flang/runtime/Float128Math/cabs.cpp
@@ -12,10 +12,8 @@ namespace Fortran::runtime {
extern "C" {
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
-// FIXME: the argument should be CppTypeFor<TypeCategory::Complex, 16>,
-// and it should be translated into the underlying library's
-// corresponding complex128 type.
-CppTypeFor<TypeCategory::Real, 16> RTDEF(CAbsF128)(ComplexF128 x) {
+// NOTE: Flang calls the runtime APIs using C _Complex ABI
+CppTypeFor<TypeCategory::Real, 16> RTDEF(CAbsF128)(CFloat128ComplexType x) {
return CAbs<RTNAME(CAbsF128)>::invoke(x);
}
#endif
diff --git a/flang/runtime/Float128Math/math-entries.h b/flang/runtime/Float128Math/math-entries.h
index fe1525468edcaf..141648d2fb2c54 100644
--- a/flang/runtime/Float128Math/math-entries.h
+++ b/flang/runtime/Float128Math/math-entries.h
@@ -91,15 +91,6 @@ DEFINE_FALLBACK(Y0)
DEFINE_FALLBACK(Y1)
DEFINE_FALLBACK(Yn)
-// Define ComplexF128 type that is compatible with
-// the type of results/arguments of libquadmath.
-// TODO: this may need more work for other libraries/compilers.
-#if !defined(_ARCH_PPC) || defined(__LONG_DOUBLE_IEEE128__)
-typedef _Complex float __attribute__((mode(TC))) ComplexF128;
-#else
-typedef _Complex float __attribute__((mode(KC))) ComplexF128;
-#endif
-
#if HAS_LIBM
// Define wrapper callers for libm.
#include <ccomplex>
diff --git a/flang/runtime/complex-reduction.c b/flang/runtime/complex-reduction.c
index d77e1c0a550069..06e4f15c7fa9b5 100644
--- a/flang/runtime/complex-reduction.c
+++ b/flang/runtime/complex-reduction.c
@@ -19,6 +19,11 @@ struct CppComplexDouble {
struct CppComplexLongDouble {
long double r, i;
};
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+struct CppComplexFloat128 {
+ CFloat128Type r, i;
+};
+#endif
/* Not all environments define CMPLXF, CMPLX, CMPLXL. */
@@ -70,6 +75,27 @@ static long_double_Complex_t CMPLXL(long double r, long double i) {
#endif
#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+/*
+ * GCC 7.4.0 (currently minimum GCC version for llvm builds)
+ * supports __builtin_complex. For Clang, require >=12.0.
+ * Otherwise, rely on the memory layout compatibility.
+ */
+#if (defined(__clang_major__) && (__clang_major__ >= 12)) || defined(__GNUC__)
+#define CMPLXF128 __builtin_complex
+#else
+static CFloat128ComplexType CMPLXF128(CFloat128Type r, CFloat128Type i) {
+ union {
+ struct CppComplexFloat128 x;
+ CFloat128ComplexType result;
+ } u;
+ u.x.r = r;
+ u.x.i = i;
+ return u.result;
+}
+#endif
+#endif
+
/* RTNAME(SumComplex4) calls RTNAME(CppSumComplex4) with the same arguments
* and converts the members of its C++ complex result to C _Complex.
*/
@@ -93,9 +119,10 @@ ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX,
#if LDBL_MANT_DIG == 64
ADAPT_REDUCTION(SumComplex10, long_double_Complex_t, CppComplexLongDouble,
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
-#elif LDBL_MANT_DIG == 113
-ADAPT_REDUCTION(SumComplex16, long_double_Complex_t, CppComplexLongDouble,
- CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+ADAPT_REDUCTION(SumComplex16, CFloat128ComplexType, CppComplexFloat128,
+ CMPLXF128, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#endif
/* PRODUCT() */
@@ -106,9 +133,10 @@ ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
#if LDBL_MANT_DIG == 64
ADAPT_REDUCTION(ProductComplex10, long_double_Complex_t, CppComplexLongDouble,
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
-#elif LDBL_MANT_DIG == 113
-ADAPT_REDUCTION(ProductComplex16, long_double_Complex_t, CppComplexLongDouble,
- CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+ADAPT_REDUCTION(ProductComplex16, CFloat128ComplexType, CppComplexFloat128,
+ CMPLXF128, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#endif
/* DOT_PRODUCT() */
@@ -119,7 +147,8 @@ ADAPT_REDUCTION(DotProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
#if LDBL_MANT_DIG == 64
ADAPT_REDUCTION(DotProductComplex10, long_double_Complex_t,
CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
-#elif LDBL_MANT_DIG == 113
-ADAPT_REDUCTION(DotProductComplex16, long_double_Complex_t,
- CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+ADAPT_REDUCTION(DotProductComplex16, CFloat128ComplexType, CppComplexFloat128,
+ CMPLXF128, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
#endif
diff --git a/flang/runtime/complex-reduction.h b/flang/runtime/complex-reduction.h
index 5c4f1f5126e393..1d37b235d5194b 100644
--- a/flang/runtime/complex-reduction.h
+++ b/flang/runtime/complex-reduction.h
@@ -15,6 +15,7 @@
#ifndef FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
#define FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
+#include "flang/Common/float128.h"
#include "flang/Runtime/entry-names.h"
#include <complex.h>
@@ -40,14 +41,18 @@ float_Complex_t RTNAME(SumComplex3)(REDUCTION_ARGS);
float_Complex_t RTNAME(SumComplex4)(REDUCTION_ARGS);
double_Complex_t RTNAME(SumComplex8)(REDUCTION_ARGS);
long_double_Complex_t RTNAME(SumComplex10)(REDUCTION_ARGS);
-long_double_Complex_t RTNAME(SumComplex16)(REDUCTION_ARGS);
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+CFloat128ComplexType RTNAME(SumComplex16)(REDUCTION_ARGS);
+#endif
float_Complex_t RTNAME(ProductComplex2)(REDUCTION_ARGS);
float_Complex_t RTNAME(ProductComplex3)(REDUCTION_ARGS);
float_Complex_t RTNAME(ProductComplex4)(REDUCTION_ARGS);
double_Complex_t RTNAME(ProductComplex8)(REDUCTION_ARGS);
long_double_Complex_t RTNAME(ProductComplex10)(REDUCTION_ARGS);
-long_double_Complex_t RTNAME(ProductComplex16)(REDUCTION_ARGS);
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+CFloat128ComplexType RTNAME(ProductComplex16)(REDUCTION_ARGS);
+#endif
#define DOT_PRODUCT_ARGS \
const struct CppDescriptor *x, const struct CppDescriptor *y, \
@@ -60,6 +65,8 @@ float_Complex_t RTNAME(DotProductComplex3)(DOT_PRODUCT_ARGS);
float_Complex_t RTNAME(DotProductComplex4)(DOT_PRODUCT_ARGS);
double_Complex_t RTNAME(DotProductComplex8)(DOT_PRODUCT_ARGS);
long_double_Complex_t RTNAME(DotProductComplex10)(DOT_PRODUCT_ARGS);
-long_double_Complex_t RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+CFloat128ComplexType RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
+#endif
#endif // FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
diff --git a/flang/runtime/product.cpp b/flang/runtime/product.cpp
index a516bc51a959b7..4c3b8c33a12e0f 100644
--- a/flang/runtime/product.cpp
+++ b/flang/runtime/product.cpp
@@ -123,7 +123,8 @@ CppTypeFor<TypeCategory::Real, 10> RTDEF(ProductReal10)(const Descriptor &x,
NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>{x},
"PRODUCT");
}
-#elif LDBL_MANT_DIG == 113
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
CppTypeFor<TypeCategory::Real, 16> RTDEF(ProductReal16)(const Descriptor &x,
const char *source, int line, int dim, const Descriptor *mask) {
return GetTotalReduction<TypeCategory::Real, 16>(x, source, line, dim, mask,
@@ -154,7 +155,8 @@ void RTDEF(CppProductComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>{x},
"PRODUCT");
}
-#elif LDBL_MANT_DIG == 113
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
const Descriptor &x, const char *source, int line, int dim,
const Descriptor *mask) {
diff --git a/flang/runtime/sum.cpp b/flang/runtime/sum.cpp
index 048399737c8501..d2495e3e956fe6 100644
--- a/flang/runtime/sum.cpp
+++ b/flang/runtime/sum.cpp
@@ -175,7 +175,8 @@ void RTDEF(CppSumComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
result = GetTotalReduction<TypeCategory::Complex, 10>(
x, source, line, dim, mask, ComplexSumAccumulator<long double>{x}, "SUM");
}
-#elif LDBL_MANT_DIG == 113
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
const Descriptor &x, const char *source, int line, int dim,
const Descriptor *mask) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/83169
More information about the flang-commits
mailing list