[flang-commits] [flang] 50e0b29 - [flang] Implement DOT_PRODUCT in the runtime
peter klausler via flang-commits
flang-commits at lists.llvm.org
Thu May 13 10:40:15 PDT 2021
Author: peter klausler
Date: 2021-05-13T10:40:07-07:00
New Revision: 50e0b2985e43baf61617c9734df71e949113f911
URL: https://github.com/llvm/llvm-project/commit/50e0b2985e43baf61617c9734df71e949113f911
DIFF: https://github.com/llvm/llvm-project/commit/50e0b2985e43baf61617c9734df71e949113f911.diff
LOG: [flang] Implement DOT_PRODUCT in the runtime
API, implementation, and basic tests for the transformational
reduction intrinsic function DOT_PRODUCT in the runtime support
library.
Differential Revision: https://reviews.llvm.org/D102351
Added:
flang/runtime/dot-product.cpp
Modified:
flang/runtime/CMakeLists.txt
flang/runtime/complex-reduction.c
flang/runtime/complex-reduction.h
flang/runtime/reduction.cpp
flang/runtime/reduction.h
flang/runtime/tools.h
flang/unittests/RuntimeGTest/Reduction.cpp
Removed:
################################################################################
diff --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt
index c63fd3dd5f182..84d13f12a1106 100644
--- a/flang/runtime/CMakeLists.txt
+++ b/flang/runtime/CMakeLists.txt
@@ -39,6 +39,7 @@ add_flang_library(FortranRuntime
connection.cpp
derived.cpp
descriptor.cpp
+ dot-product.cpp
edit-input.cpp
edit-output.cpp
environment.cpp
diff --git a/flang/runtime/complex-reduction.c b/flang/runtime/complex-reduction.c
index 3f74eeddd74e1..d0ca50f72d652 100644
--- a/flang/runtime/complex-reduction.c
+++ b/flang/runtime/complex-reduction.c
@@ -75,34 +75,54 @@ static long_double_Complex_t CMPLXL(long double r, long double i) {
*/
#define CPP_NAME(name) Cpp##name
-#define ADAPT_REDUCTION(name, cComplex, cpptype, cmplxMacro) \
- struct cpptype RTNAME(CPP_NAME(name))(struct cpptype *, REDUCTION_ARGS); \
- cComplex RTNAME(name)(REDUCTION_ARGS) { \
+/* Defines the C entry point RTNAME(name)(ARGS), returning a C _Complex value,
+ * on top of the C++ implementation RTNAME(CPP_NAME(name)), which stores its
+ * result through the leading struct pointer and returns void. */
+#define ADAPT_REDUCTION(name, cComplex, cpptype, cmplxMacro, ARGS, ARG_NAMES) \
+ void RTNAME(CPP_NAME(name))(struct cpptype *, ARGS); \
+ cComplex RTNAME(name)(ARGS) { \
struct cpptype result; \
- RTNAME(CPP_NAME(name))(&result, REDUCTION_ARG_NAMES); \
+ RTNAME(CPP_NAME(name))(&result, ARG_NAMES); \
return cmplxMacro(result.r, result.i); \
}
/* TODO: COMPLEX(2 & 3) */
/* SUM() */
-ADAPT_REDUCTION(SumComplex4, float_Complex_t, CppComplexFloat, CMPLXF)
-ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX)
+ADAPT_REDUCTION(SumComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
+ REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX,
+ REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#if LONG_DOUBLE == 80
-ADAPT_REDUCTION(
- SumComplex10, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(SumComplex10, long_double_Complex_t, CppComplexLongDouble,
+ CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#elif LONG_DOUBLE == 128
-ADAPT_REDUCTION(
- SumComplex16, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(SumComplex16, long_double_Complex_t, CppComplexLongDouble,
+ CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#endif
/* PRODUCT() */
-ADAPT_REDUCTION(ProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF)
-ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX)
+ADAPT_REDUCTION(ProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
+ REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
+ REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#if LONG_DOUBLE == 80
-ADAPT_REDUCTION(
- ProductComplex10, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(ProductComplex10, long_double_Complex_t, CppComplexLongDouble,
+ CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#elif LONG_DOUBLE == 128
-ADAPT_REDUCTION(
- ProductComplex16, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(ProductComplex16, long_double_Complex_t, CppComplexLongDouble,
+ CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+#endif
+
+/* DOT_PRODUCT() */
+ADAPT_REDUCTION(DotProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
+ DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+ADAPT_REDUCTION(DotProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
+ DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+#if LONG_DOUBLE == 80
+ADAPT_REDUCTION(DotProductComplex10, long_double_Complex_t,
+ CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+#elif LONG_DOUBLE == 128
+ADAPT_REDUCTION(DotProductComplex16, long_double_Complex_t,
+ CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
#endif
diff --git a/flang/runtime/complex-reduction.h b/flang/runtime/complex-reduction.h
index 562b9523ea797..f26847f2ded5b 100644
--- a/flang/runtime/complex-reduction.h
+++ b/flang/runtime/complex-reduction.h
@@ -49,4 +49,17 @@ double_Complex_t RTNAME(ProductComplex8)(REDUCTION_ARGS);
long_double_Complex_t RTNAME(ProductComplex10)(REDUCTION_ARGS);
long_double_Complex_t RTNAME(ProductComplex16)(REDUCTION_ARGS);
+/* Same parameters as C++ CppDotProductComplex*; DOT_PRODUCT has no DIM/MASK. */
+#define DOT_PRODUCT_ARGS \
+ const struct CppDescriptor *x, const struct CppDescriptor *y, \
+ const char *source, int line
+#define DOT_PRODUCT_ARG_NAMES x, y, source, line
+
+float_Complex_t RTNAME(DotProductComplex2)(DOT_PRODUCT_ARGS);
+float_Complex_t RTNAME(DotProductComplex3)(DOT_PRODUCT_ARGS);
+float_Complex_t RTNAME(DotProductComplex4)(DOT_PRODUCT_ARGS);
+double_Complex_t RTNAME(DotProductComplex8)(DOT_PRODUCT_ARGS);
+long_double_Complex_t RTNAME(DotProductComplex10)(DOT_PRODUCT_ARGS);
+long_double_Complex_t RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
+
#endif // FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
diff --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp
new file mode 100644
index 0000000000000..1c83d8de3bf3c
--- /dev/null
+++ b/flang/runtime/dot-product.cpp
@@ -0,0 +1,227 @@
+//===-- runtime/dot-product.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "cpp-type.h"
+#include "descriptor.h"
+#include "reduction.h"
+#include "terminator.h"
+#include "tools.h"
+#include <cinttypes>
+
+namespace Fortran::runtime {
+
+// Computes the dot product of two rank-1 descriptors X and Y by constructing
+// an ACCUMULATOR from (x, y) and handing it each pair of subscripts in turn.
+// Crashes when either operand is not a vector or when their sizes differ.
+// Each vector is traversed from its own lower bound, so X and Y need not
+// have the same bounds.
+template <typename ACCUMULATOR>
+static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y,
+ Terminator &terminator) -> typename ACCUMULATOR::Result {
+ RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
+ SubscriptValue n{x.GetDimension(0).Extent()};
+ if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
+ terminator.Crash(
+ "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
+ static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
+ }
+ SubscriptValue xAt{x.GetDimension(0).LowerBound()};
+ SubscriptValue yAt{y.GetDimension(0).LowerBound()};
+ ACCUMULATOR accumulator{x, y};
+ for (SubscriptValue j{0}; j < n; ++j) {
+ accumulator.Accumulate(xAt++, yAt++);
+ }
+ return accumulator.GetResult();
+}
+
+// Function object for one DOT_PRODUCT entry point whose value is produced as
+// Result = CppTypeFor<RCAT, RKIND>. It dispatches on the runtime category and
+// kind of X (DP1) and then of Y (DP2) through ApplyType(), and computes with
+// ACCUM only when GetResultType() of the operand types has category RCAT and
+// a kind no greater than RKIND (any kind for LOGICAL); otherwise it crashes.
+template <TypeCategory RCAT, int RKIND,
+ template <typename, TypeCategory, typename, typename> class ACCUM>
+struct DotProduct {
+ using Result = CppTypeFor<RCAT, RKIND>;
+ template <TypeCategory XCAT, int XKIND> struct DP1 {
+ template <TypeCategory YCAT, int YKIND> struct DP2 {
+ Result operator()(const Descriptor &x, const Descriptor &y,
+ Terminator &terminator) const {
+ // Both tests are compile-time, so ACCUM is instantiated only for
+ // operand type combinations that are valid for this entry point.
+ if constexpr (constexpr auto resultType{
+ GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+ // The LOGICAL entry point returns bool whatever the operand kinds,
+ // so any LOGICAL kinds are accepted; numeric results must fit RKIND.
+ if constexpr (resultType->first == RCAT &&
+ (resultType->second <= RKIND ||
+ RCAT == TypeCategory::Logical)) {
+ using Accum = ACCUM<Result, XCAT, CppTypeFor<XCAT, XKIND>,
+ CppTypeFor<YCAT, YKIND>>;
+ return DoDotProduct<Accum>(x, y, terminator);
+ }
+ }
+ terminator.Crash(
+ "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
+ static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
+ static_cast<int>(YCAT), YKIND);
+ }
+ };
+ Result operator()(const Descriptor &x, const Descriptor &y,
+ Terminator &terminator, TypeCategory yCat, int yKind) const {
+ return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
+ }
+ };
+ Result operator()(const Descriptor &x, const Descriptor &y,
+ const char *source, int line) const {
+ Terminator terminator{source, line};
+ auto xCatKind{x.type().GetCategoryAndKind()};
+ auto yCatKind{y.type().GetCategoryAndKind()};
+ RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+ return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, terminator,
+ x, y, terminator, yCatKind->first, yCatKind->second);
+ }
+};
+
+// Accumulates SUM(x(j) * y(j)) in RESULT, converting each element of X (type
+// XT) and of Y (type YT) to RESULT before multiplying. When X is COMPLEX its
+// elements are conjugated first, giving SUM(CONJG(x(j)) * y(j)).
+template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
+class NumericAccumulator {
+public:
+ using Result = RESULT;
+ NumericAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
+ void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
+ if constexpr (XCAT == TypeCategory::Complex) {
+ sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
+ static_cast<Result>(*y_.Element<YT>(&yAt));
+ } else {
+ sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
+ static_cast<Result>(*y_.Element<YT>(&yAt));
+ }
+ }
+ Result GetResult() const { return sum_; }
+
+private:
+ const Descriptor &x_, &y_;
+ Result sum_{0};
+};
+
+// Accumulates ANY(x(j) .AND. y(j)) for LOGICAL operands. The unnamed template
+// parameters exist only to match DotProduct's ACCUM template signature.
+// Once a true term has been seen, || short-circuits and later elements are
+// not examined.
+template <typename, TypeCategory, typename XT, typename YT>
+class LogicalAccumulator {
+public:
+ using Result = bool;
+ LogicalAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
+ void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
+ result_ = result_ ||
+ (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
+ }
+ bool GetResult() const { return result_; }
+
+private:
+ const Descriptor &x_, &y_;
+ bool result_{false};
+};
+
+// C-callable entry points, one per result type. Apart from KIND=10 and
+// KIND=16, the numeric entry points accumulate in the KIND=8 type and
+// convert the sum to their declared return type.
+// NOTE(review): since DotProduct accepts any operand kind up to RKIND, e.g.
+// DotProductInteger1 also accepts INTEGER(8) operands and silently narrows
+// the result -- confirm this is intended.
+extern "C" {
+std::int8_t RTNAME(DotProductInteger1)(
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
+ x, y, source, line);
+}
+std::int16_t RTNAME(DotProductInteger2)(
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
+ x, y, source, line);
+}
+std::int32_t RTNAME(DotProductInteger4)(
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
+ x, y, source, line);
+}
+std::int64_t RTNAME(DotProductInteger8)(
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
+ x, y, source, line);
+}
+#ifdef __SIZEOF_INT128__
+common::int128_t RTNAME(DotProductInteger16)(
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ return DotProduct<TypeCategory::Integer, 16, NumericAccumulator>{}(
+ x, y, source, line);
+}
+#endif
+
+// TODO: REAL/COMPLEX(2 & 3)
+float RTNAME(DotProductReal4)(
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
+ x, y, source, line);
+}
+double RTNAME(DotProductReal8)(
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
+ x, y, source, line);
+}
+#if LONG_DOUBLE == 80
+long double RTNAME(DotProductReal10)(
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ return DotProduct<TypeCategory::Real, 10, NumericAccumulator>{}(
+ x, y, source, line);
+}
+#elif LONG_DOUBLE == 128
+long double RTNAME(DotProductReal16)(
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ return DotProduct<TypeCategory::Real, 16, NumericAccumulator>{}(
+ x, y, source, line);
+}
+#endif
+
+void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ auto z{DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
+ x, y, source, line)};
+ result = std::complex<float>{
+ static_cast<float>(z.real()), static_cast<float>(z.imag())};
+}
+void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ result = DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
+ x, y, source, line);
+}
+#if LONG_DOUBLE == 80
+void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ result = DotProduct<TypeCategory::Complex, 10, NumericAccumulator>{}(
+ x, y, source, line);
+}
+#elif LONG_DOUBLE == 128
+void RTNAME(CppDotProductComplex16)(std::complex<long double> &result,
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ result = DotProduct<TypeCategory::Complex, 16, NumericAccumulator>{}(
+ x, y, source, line);
+}
+#endif
+
+bool RTNAME(DotProductLogical)(
+ const Descriptor &x, const Descriptor &y, const char *source, int line) {
+ return DotProduct<TypeCategory::Logical, 1, LogicalAccumulator>{}(
+ x, y, source, line);
+}
+} // extern "C"
+} // namespace Fortran::runtime
diff --git a/flang/runtime/reduction.cpp b/flang/runtime/reduction.cpp
index 15964cfee4662..cf9515b7ad43a 100644
--- a/flang/runtime/reduction.cpp
+++ b/flang/runtime/reduction.cpp
@@ -9,8 +9,8 @@
// Implements ALL, ANY, COUNT, IPARITY, & PARITY for all required operand
// types and shapes.
//
-// FINDLOC, SUM, and PRODUCT are in their own eponymous source files;
-// NORM2, MAXLOC, MINLOC, MAXVAL, and MINVAL are in extrema.cpp.
+// DOT_PRODUCT, FINDLOC, SUM, and PRODUCT are in their own eponymous source
+// files; NORM2, MAXLOC, MINLOC, MAXVAL, and MINVAL are in extrema.cpp.
#include "reduction.h"
#include "reduction-templates.h"
diff --git a/flang/runtime/reduction.h b/flang/runtime/reduction.h
index f6d29edf04454..cec30843da6d5 100644
--- a/flang/runtime/reduction.h
+++ b/flang/runtime/reduction.h
@@ -7,9 +7,9 @@
//===----------------------------------------------------------------------===//
// Defines the API for the reduction transformational intrinsic functions.
-// (Except the complex-valued total reduction forms of SUM and PRODUCT;
-// the API for those is in complex-reduction.h so that C's _Complex can
-// be used for their return types.)
+// (Except the complex-valued DOT_PRODUCT and the complex-valued total reduction
+// forms of SUM & PRODUCT; the API for those is in complex-reduction.h so that
+// C's _Complex can be used for their return types.)
#ifndef FORTRAN_RUNTIME_REDUCTION_H_
#define FORTRAN_RUNTIME_REDUCTION_H_
@@ -275,6 +275,56 @@ bool RTNAME(Parity)(
void RTNAME(ParityDim)(Descriptor &result, const Descriptor &, int dim,
const char *source, int line);
+// DOT_PRODUCT
+// Each entry point is named for the type of its result and takes the two
+// vector operands plus an optional source position for error messages.
+// The COMPLEX variants store their value through the leading std::complex
+// reference; C callers use the DotProductComplex* wrappers declared in
+// complex-reduction.h instead.
+// NOTE(review): the REAL/COMPLEX kind 2 & 3 variants are declared here but
+// not yet defined in dot-product.cpp, and the kind 10 & 16 variants are
+// defined there only for the matching LONG_DOUBLE configuration.
+std::int8_t RTNAME(DotProductInteger1)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+std::int16_t RTNAME(DotProductInteger2)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+std::int32_t RTNAME(DotProductInteger4)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+std::int64_t RTNAME(DotProductInteger8)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+#ifdef __SIZEOF_INT128__
+common::int128_t RTNAME(DotProductInteger16)(const Descriptor &,
+ const Descriptor &, const char *source = nullptr, int line = 0);
+#endif
+float RTNAME(DotProductReal2)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+float RTNAME(DotProductReal3)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+float RTNAME(DotProductReal4)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+double RTNAME(DotProductReal8)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+long double RTNAME(DotProductReal10)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+long double RTNAME(DotProductReal16)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex2)(std::complex<float> &, const Descriptor &,
+ const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex3)(std::complex<float> &, const Descriptor &,
+ const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex4)(std::complex<float> &, const Descriptor &,
+ const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex8)(std::complex<double> &, const Descriptor &,
+ const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex10)(std::complex<long double> &,
+ const Descriptor &, const Descriptor &, const char *source = nullptr,
+ int line = 0);
+void RTNAME(CppDotProductComplex16)(std::complex<long double> &,
+ const Descriptor &, const Descriptor &, const char *source = nullptr,
+ int line = 0);
+bool RTNAME(DotProductLogical)(const Descriptor &, const Descriptor &,
+ const char *source = nullptr, int line = 0);
+
} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_REDUCTION_H_
diff --git a/flang/runtime/tools.h b/flang/runtime/tools.h
index c5996dc3e5684..ee8c439b6cb55 100644
--- a/flang/runtime/tools.h
+++ b/flang/runtime/tools.h
@@ -102,6 +102,104 @@ inline bool SetInteger(INT &x, int kind, std::int64_t value) {
}
}
+// Maps intrinsic runtime type category and kind values to the matching
+// instantiation FUNC<CAT, KIND>, invokes a default-constructed one with the
+// forwarded arguments, and returns its RESULT; other values crash.
+template <template <TypeCategory, int> class FUNC, typename RESULT,
+ typename... A>
+inline RESULT ApplyType(
+ TypeCategory cat, int kind, Terminator &terminator, A &&...x) {
+ switch (cat) {
+ case TypeCategory::Integer:
+ switch (kind) {
+ case 1:
+ return FUNC<TypeCategory::Integer, 1>{}(std::forward<A>(x)...);
+ case 2:
+ return FUNC<TypeCategory::Integer, 2>{}(std::forward<A>(x)...);
+ case 4:
+ return FUNC<TypeCategory::Integer, 4>{}(std::forward<A>(x)...);
+ case 8:
+ return FUNC<TypeCategory::Integer, 8>{}(std::forward<A>(x)...);
+#ifdef __SIZEOF_INT128__
+ case 16:
+ return FUNC<TypeCategory::Integer, 16>{}(std::forward<A>(x)...);
+#endif
+ default:
+ terminator.Crash("unsupported INTEGER(KIND=%d)", kind); // no break follows (here or below): relies on Crash() not returning -- confirm [[noreturn]]
+ }
+ case TypeCategory::Real:
+ switch (kind) {
+#if 0 // TODO: REAL(2 & 3); until enabled they reach the default Crash below
+ case 2:
+ return FUNC<TypeCategory::Real, 2>{}(std::forward<A>(x)...);
+ case 3:
+ return FUNC<TypeCategory::Real, 3>{}(std::forward<A>(x)...);
+#endif
+ case 4:
+ return FUNC<TypeCategory::Real, 4>{}(std::forward<A>(x)...);
+ case 8:
+ return FUNC<TypeCategory::Real, 8>{}(std::forward<A>(x)...);
+#if LONG_DOUBLE == 80
+ case 10:
+ return FUNC<TypeCategory::Real, 10>{}(std::forward<A>(x)...);
+#elif LONG_DOUBLE == 128
+ case 16:
+ return FUNC<TypeCategory::Real, 16>{}(std::forward<A>(x)...);
+#endif
+ default:
+ terminator.Crash("unsupported REAL(KIND=%d)", kind);
+ }
+ case TypeCategory::Complex:
+ switch (kind) {
+#if 0 // TODO: COMPLEX(2 & 3); until enabled they reach the default Crash below
+ case 2:
+ return FUNC<TypeCategory::Complex, 2>{}(std::forward<A>(x)...);
+ case 3:
+ return FUNC<TypeCategory::Complex, 3>{}(std::forward<A>(x)...);
+#endif
+ case 4:
+ return FUNC<TypeCategory::Complex, 4>{}(std::forward<A>(x)...);
+ case 8:
+ return FUNC<TypeCategory::Complex, 8>{}(std::forward<A>(x)...);
+#if LONG_DOUBLE == 80
+ case 10:
+ return FUNC<TypeCategory::Complex, 10>{}(std::forward<A>(x)...);
+#elif LONG_DOUBLE == 128
+ case 16:
+ return FUNC<TypeCategory::Complex, 16>{}(std::forward<A>(x)...);
+#endif
+ default:
+ terminator.Crash("unsupported COMPLEX(KIND=%d)", kind);
+ }
+ case TypeCategory::Character:
+ switch (kind) {
+ case 1:
+ return FUNC<TypeCategory::Character, 1>{}(std::forward<A>(x)...);
+ case 2:
+ return FUNC<TypeCategory::Character, 2>{}(std::forward<A>(x)...);
+ case 4:
+ return FUNC<TypeCategory::Character, 4>{}(std::forward<A>(x)...);
+ default:
+ terminator.Crash("unsupported CHARACTER(KIND=%d)", kind);
+ }
+ case TypeCategory::Logical:
+ switch (kind) {
+ case 1:
+ return FUNC<TypeCategory::Logical, 1>{}(std::forward<A>(x)...);
+ case 2:
+ return FUNC<TypeCategory::Logical, 2>{}(std::forward<A>(x)...);
+ case 4:
+ return FUNC<TypeCategory::Logical, 4>{}(std::forward<A>(x)...);
+ case 8:
+ return FUNC<TypeCategory::Logical, 8>{}(std::forward<A>(x)...);
+ default:
+ terminator.Crash("unsupported LOGICAL(KIND=%d)", kind);
+ }
+ default:
+ terminator.Crash("unsupported type category(%d)", static_cast<int>(cat));
+ }
+}
+
// Maps a runtime INTEGER kind value to the appropriate instantiation of
// a function object template and calls it with the supplied arguments.
template <template <int KIND> class FUNC, typename RESULT, typename... A>
@@ -180,5 +278,40 @@ inline RESULT ApplyLogicalKind(int kind, Terminator &terminator, A &&...x) {
}
}
+// Calculate result type of (X op Y) for *, //, DOT_PRODUCT, &c.
+// Returns the (category, kind) of the result, or std::nullopt when the
+// operand categories cannot be combined: numeric categories mix with one
+// another, while CHARACTER and LOGICAL combine only with themselves.
+std::optional<std::pair<TypeCategory, int>> inline constexpr GetResultType(
+    TypeCategory xCat, int xKind, TypeCategory yCat, int yKind) {
+  int maxKind{std::max(xKind, yKind)};
+  if (xCat == TypeCategory::Character || xCat == TypeCategory::Logical) {
+    if (yCat != xCat) {
+      return std::nullopt;
+    }
+    return std::make_pair(xCat, maxKind);
+  }
+  bool xIsNumeric{xCat == TypeCategory::Integer ||
+      xCat == TypeCategory::Real || xCat == TypeCategory::Complex};
+  bool yIsNumeric{yCat == TypeCategory::Integer ||
+      yCat == TypeCategory::Real || yCat == TypeCategory::Complex};
+  if (!xIsNumeric || !yIsNumeric) {
+    return std::nullopt;
+  }
+  if (xCat == TypeCategory::Integer) {
+    // INTEGER op INTEGER keeps the wider kind; against REAL or COMPLEX,
+    // the other operand's category and kind are used unchanged.
+    return std::make_pair(
+        yCat, yCat == TypeCategory::Integer ? maxKind : yKind);
+  }
+  if (yCat == TypeCategory::Integer) {
+    // REAL or COMPLEX op INTEGER keeps X's category and kind.
+    return std::make_pair(xCat, xKind);
+  }
+  // Both are REAL or COMPLEX: COMPLEX wins, using the wider kind.
+  return std::make_pair(
+      xCat == TypeCategory::Complex ? xCat : yCat, maxKind);
+}
+
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_TOOLS_H_
diff --git a/flang/unittests/RuntimeGTest/Reduction.cpp b/flang/unittests/RuntimeGTest/Reduction.cpp
index 18f0b02cbfb5c..5a2c6fb80b379 100644
--- a/flang/unittests/RuntimeGTest/Reduction.cpp
+++ b/flang/unittests/RuntimeGTest/Reduction.cpp
@@ -431,3 +431,50 @@ TEST(Reductions, FindlocNumeric) {
EXPECT_EQ(*res.ZeroBasedIndexedElement<SubscriptValue>(1), 0);
res.Destroy();
}
+
+// Covers REAL, INTEGER, mixed REAL & INTEGER, COMPLEX (including conjugation
+// of a COMPLEX first operand), and LOGICAL operands.
+TEST(Reductions, DotProduct) {
+  auto realVector{MakeArray<TypeCategory::Real, 8>(
+      std::vector<int>{4}, std::vector<double>{0.0, -0.0, 1.0, -2.0})};
+  EXPECT_EQ(
+      RTNAME(DotProductReal8)(*realVector, *realVector, __FILE__, __LINE__),
+      5.0);
+  auto intVector{MakeArray<TypeCategory::Integer, 4>(
+      std::vector<int>{4}, std::vector<std::int32_t>{1, -2, 3, -4})};
+  EXPECT_EQ(
+      RTNAME(DotProductInteger4)(*intVector, *intVector, __FILE__, __LINE__),
+      30);
+  // REAL(8) with INTEGER(4) operands produces a REAL(8) result.
+  EXPECT_EQ(
+      RTNAME(DotProductReal8)(*realVector, *intVector, __FILE__, __LINE__),
+      11.0);
+  auto complexVector{MakeArray<TypeCategory::Complex, 4>(std::vector<int>{4},
+      std::vector<std::complex<float>>{
+          {0.0}, {-0.0, -0.0}, {1.0, -2.0}, {-2.0, 4.0}})};
+  std::complex<double> result8;
+  RTNAME(CppDotProductComplex8)
+  (result8, *realVector, *complexVector, __FILE__, __LINE__);
+  EXPECT_EQ(result8, (std::complex<double>{5.0, -10.0}));
+  // A COMPLEX first operand is conjugated, so swapping the operands
+  // conjugates the result.
+  RTNAME(CppDotProductComplex8)
+  (result8, *complexVector, *realVector, __FILE__, __LINE__);
+  EXPECT_EQ(result8, (std::complex<double>{5.0, 10.0}));
+  std::complex<float> result4;
+  RTNAME(CppDotProductComplex4)
+  (result4, *complexVector, *complexVector, __FILE__, __LINE__);
+  EXPECT_EQ(result4, (std::complex<float>{25.0, 0.0}));
+  auto logicalVector1{MakeArray<TypeCategory::Logical, 1>(
+      std::vector<int>{4}, std::vector<bool>{false, false, true, true})};
+  EXPECT_TRUE(RTNAME(DotProductLogical)(
+      *logicalVector1, *logicalVector1, __FILE__, __LINE__));
+  auto logicalVector2{MakeArray<TypeCategory::Logical, 1>(
+      std::vector<int>{4}, std::vector<bool>{true, true, false, false})};
+  EXPECT_TRUE(RTNAME(DotProductLogical)(
+      *logicalVector2, *logicalVector2, __FILE__, __LINE__));
+  EXPECT_FALSE(RTNAME(DotProductLogical)(
+      *logicalVector1, *logicalVector2, __FILE__, __LINE__));
+  EXPECT_FALSE(RTNAME(DotProductLogical)(
+      *logicalVector2, *logicalVector1, __FILE__, __LINE__));
+}
More information about the flang-commits
mailing list