[flang-commits] [flang] d699d9d - [flang][runtime] Support SUM/PRODUCT/DOT_PRODUCT reductions for REAL(16). (#83169)

via flang-commits flang-commits at lists.llvm.org
Tue Feb 27 15:59:29 PST 2024


Author: Slava Zakharin
Date: 2024-02-27T15:59:25-08:00
New Revision: d699d9d609a24d80809df15efe47ac539da90e93

URL: https://github.com/llvm/llvm-project/commit/d699d9d609a24d80809df15efe47ac539da90e93
DIFF: https://github.com/llvm/llvm-project/commit/d699d9d609a24d80809df15efe47ac539da90e93.diff

LOG: [flang][runtime] Support SUM/PRODUCT/DOT_PRODUCT reductions for REAL(16). (#83169)

The reduction implementations rely on trivial operations that
are supported by the build compiler's runtime, so they can be enabled
whenever the build compiler provides 128-bit float support.

std::conj, used by DOT_PRODUCT, is implemented as a template
in most environments, so it should not introduce a dependency
on any 128-bit float support library. I am not going to
test it in all the build environments before merging.
If it fails for someone, I will deal with it.

Added: 
    

Modified: 
    flang/include/flang/Common/float128.h
    flang/include/flang/Runtime/reduction.h
    flang/runtime/Float128Math/cabs.cpp
    flang/runtime/Float128Math/math-entries.h
    flang/runtime/complex-reduction.c
    flang/runtime/complex-reduction.h
    flang/runtime/product.cpp
    flang/runtime/sum.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Common/float128.h b/flang/include/flang/Common/float128.h
index 3443aa06437b04..6aa98df5529df2 100644
--- a/flang/include/flang/Common/float128.h
+++ b/flang/include/flang/Common/float128.h
@@ -49,4 +49,20 @@
 #endif /* (defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__)) && \
           !defined(_LIBCPP_VERSION)  && !defined(__CUDA_ARCH__) */
 
+/* Define pure C CFloat128Type and CFloat128ComplexType. */
+#if LDBL_MANT_DIG == 113
+typedef long double CFloat128Type;
+typedef long double _Complex CFloat128ComplexType;
+#elif HAS_FLOAT128
+typedef __float128 CFloat128Type;
+/*
+ * Use mode() attribute supported by GCC and Clang.
+ * Adjust it for other compilers as needed.
+ */
+#if !defined(_ARCH_PPC) || defined(__LONG_DOUBLE_IEEE128__)
+typedef _Complex float __attribute__((mode(TC))) CFloat128ComplexType;
+#else
+typedef _Complex float __attribute__((mode(KC))) CFloat128ComplexType;
+#endif
+#endif
 #endif /* FORTRAN_COMMON_FLOAT128_H_ */

diff  --git a/flang/include/flang/Runtime/reduction.h b/flang/include/flang/Runtime/reduction.h
index b91fec0cd26b51..6d62f4016937e0 100644
--- a/flang/include/flang/Runtime/reduction.h
+++ b/flang/include/flang/Runtime/reduction.h
@@ -92,9 +92,11 @@ void RTDECL(CppSumComplex8)(std::complex<double> &, const Descriptor &,
 void RTDECL(CppSumComplex10)(std::complex<long double> &, const Descriptor &,
     const char *source, int line, int dim = 0,
     const Descriptor *mask = nullptr);
-void RTDECL(CppSumComplex16)(std::complex<long double> &, const Descriptor &,
-    const char *source, int line, int dim = 0,
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+void RTDECL(CppSumComplex16)(std::complex<CppFloat128Type> &,
+    const Descriptor &, const char *source, int line, int dim = 0,
     const Descriptor *mask = nullptr);
+#endif
 
 void RTDECL(SumDim)(Descriptor &result, const Descriptor &array, int dim,
     const char *source, int line, const Descriptor *mask = nullptr);
@@ -145,12 +147,16 @@ void RTDECL(CppProductComplex4)(std::complex<float> &, const Descriptor &,
 void RTDECL(CppProductComplex8)(std::complex<double> &, const Descriptor &,
     const char *source, int line, int dim = 0,
     const Descriptor *mask = nullptr);
+#if LDBL_MANT_DIG == 64
 void RTDECL(CppProductComplex10)(std::complex<long double> &,
     const Descriptor &, const char *source, int line, int dim = 0,
     const Descriptor *mask = nullptr);
-void RTDECL(CppProductComplex16)(std::complex<long double> &,
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+void RTDECL(CppProductComplex16)(std::complex<CppFloat128Type> &,
     const Descriptor &, const char *source, int line, int dim = 0,
     const Descriptor *mask = nullptr);
+#endif
 
 void RTDECL(ProductDim)(Descriptor &result, const Descriptor &array, int dim,
     const char *source, int line, const Descriptor *mask = nullptr);

diff  --git a/flang/runtime/Float128Math/cabs.cpp b/flang/runtime/Float128Math/cabs.cpp
index 63f2bdf8e177ae..2867c8a4578a80 100644
--- a/flang/runtime/Float128Math/cabs.cpp
+++ b/flang/runtime/Float128Math/cabs.cpp
@@ -12,10 +12,8 @@ namespace Fortran::runtime {
 extern "C" {
 
 #if LDBL_MANT_DIG == 113 || HAS_FLOAT128
-// FIXME: the argument should be CppTypeFor<TypeCategory::Complex, 16>,
-// and it should be translated into the underlying library's
-// corresponding complex128 type.
-CppTypeFor<TypeCategory::Real, 16> RTDEF(CAbsF128)(ComplexF128 x) {
+// NOTE: Flang calls the runtime APIs using C _Complex ABI
+CppTypeFor<TypeCategory::Real, 16> RTDEF(CAbsF128)(CFloat128ComplexType x) {
   return CAbs<RTNAME(CAbsF128)>::invoke(x);
 }
 #endif

diff  --git a/flang/runtime/Float128Math/math-entries.h b/flang/runtime/Float128Math/math-entries.h
index fe1525468edcaf..141648d2fb2c54 100644
--- a/flang/runtime/Float128Math/math-entries.h
+++ b/flang/runtime/Float128Math/math-entries.h
@@ -91,15 +91,6 @@ DEFINE_FALLBACK(Y0)
 DEFINE_FALLBACK(Y1)
 DEFINE_FALLBACK(Yn)
 
-// Define ComplexF128 type that is compatible with
-// the type of results/arguments of libquadmath.
-// TODO: this may need more work for other libraries/compilers.
-#if !defined(_ARCH_PPC) || defined(__LONG_DOUBLE_IEEE128__)
-typedef _Complex float __attribute__((mode(TC))) ComplexF128;
-#else
-typedef _Complex float __attribute__((mode(KC))) ComplexF128;
-#endif
-
 #if HAS_LIBM
 // Define wrapper callers for libm.
 #include <ccomplex>

diff  --git a/flang/runtime/complex-reduction.c b/flang/runtime/complex-reduction.c
index d77e1c0a550069..06e4f15c7fa9b5 100644
--- a/flang/runtime/complex-reduction.c
+++ b/flang/runtime/complex-reduction.c
@@ -19,6 +19,11 @@ struct CppComplexDouble {
 struct CppComplexLongDouble {
   long double r, i;
 };
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+struct CppComplexFloat128 {
+  CFloat128Type r, i;
+};
+#endif
 
 /* Not all environments define CMPLXF, CMPLX, CMPLXL. */
 
@@ -70,6 +75,27 @@ static long_double_Complex_t CMPLXL(long double r, long double i) {
 #endif
 #endif
 
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+/*
+ * __builtin_complex: GCC >= 7.4.0 (minimum for llvm builds), Clang >= 12.
+ * Clang also predefines __GNUC__, so the GCC test must exclude Clang.
+ * Otherwise, rely on the memory layout compatibility. */
+#if (defined(__clang_major__) && (__clang_major__ >= 12)) || \
+    (defined(__GNUC__) && !defined(__clang__))
+#define CMPLXF128 __builtin_complex
+#else
+static CFloat128ComplexType CMPLXF128(CFloat128Type r, CFloat128Type i) {
+  union {
+    struct CppComplexFloat128 x;
+    CFloat128ComplexType result;
+  } u;
+  u.x.r = r;
+  u.x.i = i;
+  return u.result;
+}
+#endif
+#endif
+
 /* RTNAME(SumComplex4) calls RTNAME(CppSumComplex4) with the same arguments
  * and converts the members of its C++ complex result to C _Complex.
  */
@@ -93,9 +119,10 @@ ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX,
 #if LDBL_MANT_DIG == 64
 ADAPT_REDUCTION(SumComplex10, long_double_Complex_t, CppComplexLongDouble,
     CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
-#elif LDBL_MANT_DIG == 113
-ADAPT_REDUCTION(SumComplex16, long_double_Complex_t, CppComplexLongDouble,
-    CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+ADAPT_REDUCTION(SumComplex16, CFloat128ComplexType, CppComplexFloat128,
+    CMPLXF128, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #endif
 
 /* PRODUCT() */
@@ -106,9 +133,10 @@ ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
 #if LDBL_MANT_DIG == 64
 ADAPT_REDUCTION(ProductComplex10, long_double_Complex_t, CppComplexLongDouble,
     CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
-#elif LDBL_MANT_DIG == 113
-ADAPT_REDUCTION(ProductComplex16, long_double_Complex_t, CppComplexLongDouble,
-    CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+ADAPT_REDUCTION(ProductComplex16, CFloat128ComplexType, CppComplexFloat128,
+    CMPLXF128, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #endif
 
 /* DOT_PRODUCT() */
@@ -119,7 +147,8 @@ ADAPT_REDUCTION(DotProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
 #if LDBL_MANT_DIG == 64
 ADAPT_REDUCTION(DotProductComplex10, long_double_Complex_t,
     CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
-#elif LDBL_MANT_DIG == 113
-ADAPT_REDUCTION(DotProductComplex16, long_double_Complex_t,
-    CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+ADAPT_REDUCTION(DotProductComplex16, CFloat128ComplexType, CppComplexFloat128,
+    CMPLXF128, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
 #endif

diff  --git a/flang/runtime/complex-reduction.h b/flang/runtime/complex-reduction.h
index 5c4f1f5126e393..1d37b235d5194b 100644
--- a/flang/runtime/complex-reduction.h
+++ b/flang/runtime/complex-reduction.h
@@ -15,6 +15,7 @@
 #ifndef FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
 #define FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
 
+#include "flang/Common/float128.h"
 #include "flang/Runtime/entry-names.h"
 #include <complex.h>
 
@@ -40,14 +41,18 @@ float_Complex_t RTNAME(SumComplex3)(REDUCTION_ARGS);
 float_Complex_t RTNAME(SumComplex4)(REDUCTION_ARGS);
 double_Complex_t RTNAME(SumComplex8)(REDUCTION_ARGS);
 long_double_Complex_t RTNAME(SumComplex10)(REDUCTION_ARGS);
-long_double_Complex_t RTNAME(SumComplex16)(REDUCTION_ARGS);
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+CFloat128ComplexType RTNAME(SumComplex16)(REDUCTION_ARGS);
+#endif
 
 float_Complex_t RTNAME(ProductComplex2)(REDUCTION_ARGS);
 float_Complex_t RTNAME(ProductComplex3)(REDUCTION_ARGS);
 float_Complex_t RTNAME(ProductComplex4)(REDUCTION_ARGS);
 double_Complex_t RTNAME(ProductComplex8)(REDUCTION_ARGS);
 long_double_Complex_t RTNAME(ProductComplex10)(REDUCTION_ARGS);
-long_double_Complex_t RTNAME(ProductComplex16)(REDUCTION_ARGS);
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+CFloat128ComplexType RTNAME(ProductComplex16)(REDUCTION_ARGS);
+#endif
 
 #define DOT_PRODUCT_ARGS \
   const struct CppDescriptor *x, const struct CppDescriptor *y, \
@@ -60,6 +65,8 @@ float_Complex_t RTNAME(DotProductComplex3)(DOT_PRODUCT_ARGS);
 float_Complex_t RTNAME(DotProductComplex4)(DOT_PRODUCT_ARGS);
 double_Complex_t RTNAME(DotProductComplex8)(DOT_PRODUCT_ARGS);
 long_double_Complex_t RTNAME(DotProductComplex10)(DOT_PRODUCT_ARGS);
-long_double_Complex_t RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+CFloat128ComplexType RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
+#endif
 
 #endif // FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_

diff  --git a/flang/runtime/product.cpp b/flang/runtime/product.cpp
index a516bc51a959b7..4c3b8c33a12e0f 100644
--- a/flang/runtime/product.cpp
+++ b/flang/runtime/product.cpp
@@ -123,7 +123,8 @@ CppTypeFor<TypeCategory::Real, 10> RTDEF(ProductReal10)(const Descriptor &x,
       NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>{x},
       "PRODUCT");
 }
-#elif LDBL_MANT_DIG == 113
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
 CppTypeFor<TypeCategory::Real, 16> RTDEF(ProductReal16)(const Descriptor &x,
     const char *source, int line, int dim, const Descriptor *mask) {
   return GetTotalReduction<TypeCategory::Real, 16>(x, source, line, dim, mask,
@@ -154,7 +155,8 @@ void RTDEF(CppProductComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
       mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>{x},
       "PRODUCT");
 }
-#elif LDBL_MANT_DIG == 113
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
 void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
     const Descriptor &x, const char *source, int line, int dim,
     const Descriptor *mask) {

diff  --git a/flang/runtime/sum.cpp b/flang/runtime/sum.cpp
index 048399737c8501..d2495e3e956fe6 100644
--- a/flang/runtime/sum.cpp
+++ b/flang/runtime/sum.cpp
@@ -175,7 +175,8 @@ void RTDEF(CppSumComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
   result = GetTotalReduction<TypeCategory::Complex, 10>(
       x, source, line, dim, mask, ComplexSumAccumulator<long double>{x}, "SUM");
 }
-#elif LDBL_MANT_DIG == 113
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
 void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
     const Descriptor &x, const char *source, int line, int dim,
     const Descriptor *mask) {


        


More information about the flang-commits mailing list