[flang-commits] [flang] 50e0b29 - [flang] Implement DOT_PRODUCT in the runtime

peter klausler via flang-commits flang-commits at lists.llvm.org
Thu May 13 10:40:15 PDT 2021


Author: peter klausler
Date: 2021-05-13T10:40:07-07:00
New Revision: 50e0b2985e43baf61617c9734df71e949113f911

URL: https://github.com/llvm/llvm-project/commit/50e0b2985e43baf61617c9734df71e949113f911
DIFF: https://github.com/llvm/llvm-project/commit/50e0b2985e43baf61617c9734df71e949113f911.diff

LOG: [flang] Implement DOT_PRODUCT in the runtime

API, implementation, and basic tests for the transformational
reduction intrinsic function DOT_PRODUCT in the runtime support
library.

Differential Revision: https://reviews.llvm.org/D102351

Added: 
    flang/runtime/dot-product.cpp

Modified: 
    flang/runtime/CMakeLists.txt
    flang/runtime/complex-reduction.c
    flang/runtime/complex-reduction.h
    flang/runtime/reduction.cpp
    flang/runtime/reduction.h
    flang/runtime/tools.h
    flang/unittests/RuntimeGTest/Reduction.cpp

Removed: 
    


################################################################################
diff  --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt
index c63fd3dd5f182..84d13f12a1106 100644
--- a/flang/runtime/CMakeLists.txt
+++ b/flang/runtime/CMakeLists.txt
@@ -39,6 +39,7 @@ add_flang_library(FortranRuntime
   connection.cpp
   derived.cpp
   descriptor.cpp
+  dot-product.cpp
   edit-input.cpp
   edit-output.cpp
   environment.cpp

diff  --git a/flang/runtime/complex-reduction.c b/flang/runtime/complex-reduction.c
index 3f74eeddd74e1..d0ca50f72d652 100644
--- a/flang/runtime/complex-reduction.c
+++ b/flang/runtime/complex-reduction.c
@@ -75,34 +75,51 @@ static long_double_Complex_t CMPLXL(long double r, long double i) {
  */
 
 #define CPP_NAME(name) Cpp##name
-#define ADAPT_REDUCTION(name, cComplex, cpptype, cmplxMacro) \
-  struct cpptype RTNAME(CPP_NAME(name))(struct cpptype *, REDUCTION_ARGS); \
-  cComplex RTNAME(name)(REDUCTION_ARGS) { \
+#define ADAPT_REDUCTION(name, cComplex, cpptype, cmplxMacro, ARGS, ARG_NAMES) \
+  void RTNAME(CPP_NAME(name))(struct cpptype *, ARGS); \
+  cComplex RTNAME(name)(ARGS) { \
     struct cpptype result; \
-    RTNAME(CPP_NAME(name))(&result, REDUCTION_ARG_NAMES); \
+    RTNAME(CPP_NAME(name))(&result, ARG_NAMES); /* fills in result */ \
     return cmplxMacro(result.r, result.i); \
   }
 
 /* TODO: COMPLEX(2 & 3) */
 
 /* SUM() */
-ADAPT_REDUCTION(SumComplex4, float_Complex_t, CppComplexFloat, CMPLXF)
-ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX)
+ADAPT_REDUCTION(SumComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
+    REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX,
+    REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #if LONG_DOUBLE == 80
-ADAPT_REDUCTION(
-    SumComplex10, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(SumComplex10, long_double_Complex_t, CppComplexLongDouble,
+    CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #elif LONG_DOUBLE == 128
-ADAPT_REDUCTION(
-    SumComplex16, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(SumComplex16, long_double_Complex_t, CppComplexLongDouble,
+    CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #endif
 
 /* PRODUCT() */
-ADAPT_REDUCTION(ProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF)
-ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX)
+ADAPT_REDUCTION(ProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
+    REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
+    REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #if LONG_DOUBLE == 80
-ADAPT_REDUCTION(
-    ProductComplex10, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(ProductComplex10, long_double_Complex_t, CppComplexLongDouble,
+    CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #elif LONG_DOUBLE == 128
-ADAPT_REDUCTION(
-    ProductComplex16, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(ProductComplex16, long_double_Complex_t, CppComplexLongDouble,
+    CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+#endif
+
+/* DOT_PRODUCT(): C wrappers returning _Complex from the CppDotProduct* */
+ADAPT_REDUCTION(DotProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
+    DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+ADAPT_REDUCTION(DotProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
+    DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+#if LONG_DOUBLE == 80
+ADAPT_REDUCTION(DotProductComplex10, long_double_Complex_t,
+    CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+#elif LONG_DOUBLE == 128
+ADAPT_REDUCTION(DotProductComplex16, long_double_Complex_t,
+    CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
 #endif

diff  --git a/flang/runtime/complex-reduction.h b/flang/runtime/complex-reduction.h
index 562b9523ea797..f26847f2ded5b 100644
--- a/flang/runtime/complex-reduction.h
+++ b/flang/runtime/complex-reduction.h
@@ -49,4 +49,17 @@ double_Complex_t RTNAME(ProductComplex8)(REDUCTION_ARGS);
 long_double_Complex_t RTNAME(ProductComplex10)(REDUCTION_ARGS);
 long_double_Complex_t RTNAME(ProductComplex16)(REDUCTION_ARGS);
 
+/* Matches CppDotProductComplex*() in reduction.h: no DIM= or MASK= args */
+#define DOT_PRODUCT_ARGS \
+  const struct CppDescriptor *x, const struct CppDescriptor *y, \
+      const char *source, int line
+#define DOT_PRODUCT_ARG_NAMES x, y, source, line
+
+float_Complex_t RTNAME(DotProductComplex2)(DOT_PRODUCT_ARGS);
+float_Complex_t RTNAME(DotProductComplex3)(DOT_PRODUCT_ARGS);
+float_Complex_t RTNAME(DotProductComplex4)(DOT_PRODUCT_ARGS);
+double_Complex_t RTNAME(DotProductComplex8)(DOT_PRODUCT_ARGS);
+long_double_Complex_t RTNAME(DotProductComplex10)(DOT_PRODUCT_ARGS);
+long_double_Complex_t RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
+
 #endif // FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_

diff  --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp
new file mode 100644
index 0000000000000..1c83d8de3bf3c
--- /dev/null
+++ b/flang/runtime/dot-product.cpp
@@ -0,0 +1,219 @@
+//===-- runtime/dot-product.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "cpp-type.h"
+#include "descriptor.h"
+#include "reduction.h"
+#include "terminator.h"
+#include "tools.h"
+#include <cinttypes>
+
+namespace Fortran::runtime {
+
+// Applies ACCUMULATOR to each pair of corresponding elements of the operands
+// x (VECTOR_A) and y (VECTOR_B) and returns the accumulated result. Fails a
+// RUNTIME_CHECK unless both are rank-1; crashes if their sizes differ.
+template <typename ACCUMULATOR>
+static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y,
+    Terminator &terminator) -> typename ACCUMULATOR::Result {
+  RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
+  SubscriptValue n{x.GetDimension(0).Extent()};
+  if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
+    terminator.Crash(
+        "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
+        static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
+  }
+  // The operands' lower bounds may differ, so each one is walked with its
+  // own subscript starting at its own lower bound.
+  SubscriptValue xAt{x.GetDimension(0).LowerBound()};
+  SubscriptValue yAt{y.GetDimension(0).LowerBound()};
+  ACCUMULATOR accumulator{x, y};
+  for (SubscriptValue j{0}; j < n; ++j) {
+    accumulator.Accumulate(xAt++, yAt++);
+  }
+  return accumulator.GetResult();
+}
+
+// Function object that dispatches on the dynamic types of both operands.
+// RCAT and RKIND select the type in which the result is accumulated. An
+// operand pair is accepted when GetResultType() gives it category RCAT and,
+// for non-LOGICAL results, a kind no greater than RKIND; LOGICAL operands
+// of any kinds are accepted because LogicalAccumulator just yields a bool.
+template <TypeCategory RCAT, int RKIND,
+    template <typename, TypeCategory, typename, typename> class ACCUM>
+struct DotProduct {
+  using Result = CppTypeFor<RCAT, RKIND>;
+  // DP1 fixes the type of VECTOR_A; its nested DP2 then fixes VECTOR_B's.
+  template <TypeCategory XCAT, int XKIND> struct DP1 {
+    template <TypeCategory YCAT, int YKIND> struct DP2 {
+      Result operator()(const Descriptor &x, const Descriptor &y,
+          Terminator &terminator) const {
+        if constexpr (constexpr auto resultType{
+                          GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+          // The kind limit must not apply to LOGICAL: DotProductLogical uses
+          // RKIND 1, which would otherwise reject LOGICAL(2/4/8) operands.
+          if constexpr (resultType->first == RCAT &&
+              (resultType->second <= RKIND ||
+                  RCAT == TypeCategory::Logical)) {
+            using Accum = ACCUM<Result, XCAT, CppTypeFor<XCAT, XKIND>,
+                CppTypeFor<YCAT, YKIND>>;
+            return DoDotProduct<Accum>(x, y, terminator);
+          }
+        }
+        terminator.Crash(
+            "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
+            static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
+            static_cast<int>(YCAT), YKIND);
+      }
+    };
+    Result operator()(const Descriptor &x, const Descriptor &y,
+        Terminator &terminator, TypeCategory yCat, int yKind) const {
+      return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
+    }
+  };
+  Result operator()(const Descriptor &x, const Descriptor &y,
+      const char *source, int line) const {
+    Terminator terminator{source, line};
+    auto xCatKind{x.type().GetCategoryAndKind()};
+    auto yCatKind{y.type().GetCategoryAndKind()};
+    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+    return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, terminator,
+        x, y, terminator, yCatKind->first, yCatKind->second);
+  }
+};
+
+// Accumulates SUM(x(j) * y(j)) in the RESULT type; when VECTOR_A is COMPLEX,
+// each of its elements is conjugated first, as DOT_PRODUCT requires.
+template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
+class NumericAccumulator {
+public:
+  using Result = RESULT;
+  NumericAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
+  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
+    if constexpr (XCAT == TypeCategory::Complex) {
+      sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
+          static_cast<Result>(*y_.Element<YT>(&yAt));
+    } else {
+      sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
+          static_cast<Result>(*y_.Element<YT>(&yAt));
+    }
+  }
+  Result GetResult() const { return sum_; }
+
+private:
+  const Descriptor &x_, &y_;
+  Result sum_{0};
+};
+
+// Accumulates ANY(x(j) .AND. y(j)) for LOGICAL operands.
+template <typename, TypeCategory, typename XT, typename YT>
+class LogicalAccumulator {
+public:
+  using Result = bool;
+  LogicalAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
+  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
+    result_ = result_ ||
+        (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
+  }
+  bool GetResult() const { return result_; }
+
+private:
+  const Descriptor &x_, &y_;
+  bool result_{false};
+};
+
+// Entry points. INTEGER, REAL, and COMPLEX results of kind 8 or smaller are
+// accumulated in kind 8 and converted to the requested kind on return;
+// wider kinds are accumulated in their own kind.
+extern "C" {
+std::int8_t RTNAME(DotProductInteger1)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
+      x, y, source, line);
+}
+std::int16_t RTNAME(DotProductInteger2)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
+      x, y, source, line);
+}
+std::int32_t RTNAME(DotProductInteger4)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
+      x, y, source, line);
+}
+std::int64_t RTNAME(DotProductInteger8)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
+      x, y, source, line);
+}
+#ifdef __SIZEOF_INT128__
+common::int128_t RTNAME(DotProductInteger16)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct<TypeCategory::Integer, 16, NumericAccumulator>{}(
+      x, y, source, line);
+}
+#endif
+
+// TODO: REAL/COMPLEX(2 & 3)
+float RTNAME(DotProductReal4)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
+      x, y, source, line);
+}
+double RTNAME(DotProductReal8)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
+      x, y, source, line);
+}
+#if LONG_DOUBLE == 80
+long double RTNAME(DotProductReal10)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct<TypeCategory::Real, 10, NumericAccumulator>{}(
+      x, y, source, line);
+}
+#elif LONG_DOUBLE == 128
+long double RTNAME(DotProductReal16)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct<TypeCategory::Real, 16, NumericAccumulator>{}(
+      x, y, source, line);
+}
+#endif
+
+void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  auto z{DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
+      x, y, source, line)};
+  result = std::complex<float>{
+      static_cast<float>(z.real()), static_cast<float>(z.imag())};
+}
+void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  result = DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
+      x, y, source, line);
+}
+#if LONG_DOUBLE == 80
+void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  result = DotProduct<TypeCategory::Complex, 10, NumericAccumulator>{}(
+      x, y, source, line);
+}
+#elif LONG_DOUBLE == 128
+void RTNAME(CppDotProductComplex16)(std::complex<long double> &result,
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  result = DotProduct<TypeCategory::Complex, 16, NumericAccumulator>{}(
+      x, y, source, line);
+}
+#endif
+
+bool RTNAME(DotProductLogical)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct<TypeCategory::Logical, 1, LogicalAccumulator>{}(
+      x, y, source, line);
+}
+} // extern "C"
+} // namespace Fortran::runtime

diff  --git a/flang/runtime/reduction.cpp b/flang/runtime/reduction.cpp
index 15964cfee4662..cf9515b7ad43a 100644
--- a/flang/runtime/reduction.cpp
+++ b/flang/runtime/reduction.cpp
@@ -9,8 +9,8 @@
 // Implements ALL, ANY, COUNT, IPARITY, & PARITY for all required operand
 // types and shapes.
 //
-// FINDLOC, SUM, and PRODUCT are in their own eponymous source files;
-// NORM2, MAXLOC, MINLOC, MAXVAL, and MINVAL are in extrema.cpp.
+// DOT_PRODUCT, FINDLOC, SUM, and PRODUCT are in their own eponymous source
+// files; NORM2, MAXLOC, MINLOC, MAXVAL, and MINVAL are in extrema.cpp.
 
 #include "reduction.h"
 #include "reduction-templates.h"

diff  --git a/flang/runtime/reduction.h b/flang/runtime/reduction.h
index f6d29edf04454..cec30843da6d5 100644
--- a/flang/runtime/reduction.h
+++ b/flang/runtime/reduction.h
@@ -7,9 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 // Defines the API for the reduction transformational intrinsic functions.
-// (Except the complex-valued total reduction forms of SUM and PRODUCT;
-// the API for those is in complex-reduction.h so that C's _Complex can
-// be used for their return types.)
+// (Except the complex-valued DOT_PRODUCT and the complex-valued total reduction
+// forms of SUM & PRODUCT; the API for those is in complex-reduction.h so that
+// C's _Complex can be used for their return types.)
 
 #ifndef FORTRAN_RUNTIME_REDUCTION_H_
 #define FORTRAN_RUNTIME_REDUCTION_H_
@@ -275,6 +275,48 @@ bool RTNAME(Parity)(
 void RTNAME(ParityDim)(Descriptor &result, const Descriptor &, int dim,
     const char *source, int line);
 
+// DOT_PRODUCT(VECTOR_A, VECTOR_B): rank-1 operands of equal size
+std::int8_t RTNAME(DotProductInteger1)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+std::int16_t RTNAME(DotProductInteger2)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+std::int32_t RTNAME(DotProductInteger4)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+std::int64_t RTNAME(DotProductInteger8)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+#ifdef __SIZEOF_INT128__
+common::int128_t RTNAME(DotProductInteger16)(const Descriptor &,
+    const Descriptor &, const char *source = nullptr, int line = 0);
+#endif
+float RTNAME(DotProductReal2)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+float RTNAME(DotProductReal3)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+float RTNAME(DotProductReal4)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+double RTNAME(DotProductReal8)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+long double RTNAME(DotProductReal10)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+long double RTNAME(DotProductReal16)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex2)(std::complex<float> &, const Descriptor &,
+    const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex3)(std::complex<float> &, const Descriptor &,
+    const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex4)(std::complex<float> &, const Descriptor &,
+    const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex8)(std::complex<double> &, const Descriptor &,
+    const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex10)(std::complex<long double> &,
+    const Descriptor &, const Descriptor &, const char *source = nullptr,
+    int line = 0);
+void RTNAME(CppDotProductComplex16)(std::complex<long double> &,
+    const Descriptor &, const Descriptor &, const char *source = nullptr,
+    int line = 0);
+bool RTNAME(DotProductLogical)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+
 } // extern "C"
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_REDUCTION_H_

diff  --git a/flang/runtime/tools.h b/flang/runtime/tools.h
index c5996dc3e5684..ee8c439b6cb55 100644
--- a/flang/runtime/tools.h
+++ b/flang/runtime/tools.h
@@ -102,6 +102,104 @@ inline bool SetInteger(INT &x, int kind, std::int64_t value) {
   }
 }
 
+// Maps a runtime type category and kind to the matching instantiation
+// FUNC<CAT, KIND> of a function object template and calls it with the
+// supplied arguments; unsupported categories and kinds crash via terminator.
+template <template <TypeCategory, int> class FUNC, typename RESULT,
+    typename... A>
+inline RESULT ApplyType(
+    TypeCategory cat, int kind, Terminator &terminator, A &&...x) {
+  switch (cat) {
+  case TypeCategory::Integer:
+    switch (kind) {
+    case 1:
+      return FUNC<TypeCategory::Integer, 1>{}(std::forward<A>(x)...);
+    case 2:
+      return FUNC<TypeCategory::Integer, 2>{}(std::forward<A>(x)...);
+    case 4:
+      return FUNC<TypeCategory::Integer, 4>{}(std::forward<A>(x)...);
+    case 8:
+      return FUNC<TypeCategory::Integer, 8>{}(std::forward<A>(x)...);
+#ifdef __SIZEOF_INT128__
+    case 16:
+      return FUNC<TypeCategory::Integer, 16>{}(std::forward<A>(x)...);
+#endif
+    default:
+      terminator.Crash("unsupported INTEGER(KIND=%d)", kind);
+    }
+  case TypeCategory::Real:
+    switch (kind) {
+#if 0 // TODO: REAL(2 & 3)
+    case 2:
+      return FUNC<TypeCategory::Real, 2>{}(std::forward<A>(x)...);
+    case 3:
+      return FUNC<TypeCategory::Real, 3>{}(std::forward<A>(x)...);
+#endif
+    case 4:
+      return FUNC<TypeCategory::Real, 4>{}(std::forward<A>(x)...);
+    case 8:
+      return FUNC<TypeCategory::Real, 8>{}(std::forward<A>(x)...);
+#if LONG_DOUBLE == 80
+    case 10:
+      return FUNC<TypeCategory::Real, 10>{}(std::forward<A>(x)...);
+#elif LONG_DOUBLE == 128
+    case 16:
+      return FUNC<TypeCategory::Real, 16>{}(std::forward<A>(x)...);
+#endif
+    default:
+      terminator.Crash("unsupported REAL(KIND=%d)", kind);
+    }
+  case TypeCategory::Complex:
+    switch (kind) {
+#if 0 // TODO: COMPLEX(2 & 3)
+    case 2:
+      return FUNC<TypeCategory::Complex, 2>{}(std::forward<A>(x)...);
+    case 3:
+      return FUNC<TypeCategory::Complex, 3>{}(std::forward<A>(x)...);
+#endif
+    case 4:
+      return FUNC<TypeCategory::Complex, 4>{}(std::forward<A>(x)...);
+    case 8:
+      return FUNC<TypeCategory::Complex, 8>{}(std::forward<A>(x)...);
+#if LONG_DOUBLE == 80
+    case 10:
+      return FUNC<TypeCategory::Complex, 10>{}(std::forward<A>(x)...);
+#elif LONG_DOUBLE == 128
+    case 16:
+      return FUNC<TypeCategory::Complex, 16>{}(std::forward<A>(x)...);
+#endif
+    default:
+      terminator.Crash("unsupported COMPLEX(KIND=%d)", kind);
+    }
+  case TypeCategory::Character:
+    switch (kind) {
+    case 1:
+      return FUNC<TypeCategory::Character, 1>{}(std::forward<A>(x)...);
+    case 2:
+      return FUNC<TypeCategory::Character, 2>{}(std::forward<A>(x)...);
+    case 4:
+      return FUNC<TypeCategory::Character, 4>{}(std::forward<A>(x)...);
+    default:
+      terminator.Crash("unsupported CHARACTER(KIND=%d)", kind);
+    }
+  case TypeCategory::Logical:
+    switch (kind) {
+    case 1:
+      return FUNC<TypeCategory::Logical, 1>{}(std::forward<A>(x)...);
+    case 2:
+      return FUNC<TypeCategory::Logical, 2>{}(std::forward<A>(x)...);
+    case 4:
+      return FUNC<TypeCategory::Logical, 4>{}(std::forward<A>(x)...);
+    case 8:
+      return FUNC<TypeCategory::Logical, 8>{}(std::forward<A>(x)...);
+    default:
+      terminator.Crash("unsupported LOGICAL(KIND=%d)", kind);
+    }
+  default:
+    terminator.Crash("unsupported type category(%d)", static_cast<int>(cat));
+  }
+}
+
 // Maps a runtime INTEGER kind value to the appropriate instantiation of
 // a function object template and calls it with the supplied arguments.
 template <template <int KIND> class FUNC, typename RESULT, typename... A>
@@ -180,5 +278,61 @@ inline RESULT ApplyLogicalKind(int kind, Terminator &terminator, A &&...x) {
   }
 }
 
+// Result type of (X op Y) for *, //, DOT_PRODUCT, &c.; nullopt if invalid.
+std::optional<std::pair<TypeCategory, int>> inline constexpr GetResultType(
+    TypeCategory xCat, int xKind, TypeCategory yCat, int yKind) {
+  int maxKind{std::max(xKind, yKind)};
+  switch (xCat) {
+  case TypeCategory::Integer:
+    switch (yCat) {
+    case TypeCategory::Integer:
+      return std::make_pair(TypeCategory::Integer, maxKind);
+    case TypeCategory::Real:
+    case TypeCategory::Complex:
+      return std::make_pair(yCat, yKind);
+    default:
+      break;
+    }
+    break;
+  case TypeCategory::Real:
+    switch (yCat) {
+    case TypeCategory::Integer:
+      return std::make_pair(TypeCategory::Real, xKind);
+    case TypeCategory::Real:
+    case TypeCategory::Complex:
+      return std::make_pair(yCat, maxKind);
+    default:
+      break;
+    }
+    break;
+  case TypeCategory::Complex:
+    switch (yCat) {
+    case TypeCategory::Integer:
+      return std::make_pair(TypeCategory::Complex, xKind);
+    case TypeCategory::Real:
+    case TypeCategory::Complex:
+      return std::make_pair(TypeCategory::Complex, maxKind);
+    default:
+      break;
+    }
+    break;
+  case TypeCategory::Character:
+    if (yCat == TypeCategory::Character) {
+      return std::make_pair(TypeCategory::Character, maxKind);
+    } else {
+      return std::nullopt;
+    }
+  case TypeCategory::Logical:
+    if (yCat == TypeCategory::Logical) {
+      return std::make_pair(TypeCategory::Logical, maxKind);
+    } else {
+      return std::nullopt;
+    }
+  default:
+    break;
+  }
+  return std::nullopt;
+}
+
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_TOOLS_H_

diff  --git a/flang/unittests/RuntimeGTest/Reduction.cpp b/flang/unittests/RuntimeGTest/Reduction.cpp
index 18f0b02cbfb5c..5a2c6fb80b379 100644
--- a/flang/unittests/RuntimeGTest/Reduction.cpp
+++ b/flang/unittests/RuntimeGTest/Reduction.cpp
@@ -431,3 +431,50 @@ TEST(Reductions, FindlocNumeric) {
   EXPECT_EQ(*res.ZeroBasedIndexedElement<SubscriptValue>(1), 0);
   res.Destroy();
 }
+
+TEST(Reductions, DotProduct) {
+  auto realVector{MakeArray<TypeCategory::Real, 8>(
+      std::vector<int>{4}, std::vector<double>{0.0, -0.0, 1.0, -2.0})};
+  EXPECT_EQ(
+      RTNAME(DotProductReal8)(*realVector, *realVector, __FILE__, __LINE__),
+      5.0);
+  // INTEGER operands, and mixed INTEGER/REAL operands (promoted to REAL)
+  auto intVectorA{MakeArray<TypeCategory::Integer, 4>(
+      std::vector<int>{4}, std::vector<std::int32_t>{1, -2, 3, 4})};
+  auto intVectorB{MakeArray<TypeCategory::Integer, 4>(
+      std::vector<int>{4}, std::vector<std::int32_t>{5, 6, -7, 8})};
+  EXPECT_EQ(RTNAME(DotProductInteger4)(
+                *intVectorA, *intVectorB, __FILE__, __LINE__),
+      4);
+  EXPECT_EQ(RTNAME(DotProductReal8)(
+                *intVectorA, *realVector, __FILE__, __LINE__),
+      -5.0);
+  // COMPLEX operands: a COMPLEX VECTOR_A is conjugated, a REAL one is not
+  auto complexVector{MakeArray<TypeCategory::Complex, 4>(std::vector<int>{4},
+      std::vector<std::complex<float>>{
+          {0.0}, {-0.0, -0.0}, {1.0, -2.0}, {-2.0, 4.0}})};
+  std::complex<double> result8;
+  RTNAME(CppDotProductComplex8)
+  (result8, *realVector, *complexVector, __FILE__, __LINE__);
+  EXPECT_EQ(result8, (std::complex<double>{5.0, -10.0}));
+  RTNAME(CppDotProductComplex8)
+  (result8, *complexVector, *realVector, __FILE__, __LINE__);
+  EXPECT_EQ(result8, (std::complex<double>{5.0, 10.0}));
+  std::complex<float> result4;
+  RTNAME(CppDotProductComplex4)
+  (result4, *complexVector, *complexVector, __FILE__, __LINE__);
+  EXPECT_EQ(result4, (std::complex<float>{25.0, 0.0}));
+  // LOGICAL: ANY(VECTOR_A .AND. VECTOR_B)
+  auto logicalVector1{MakeArray<TypeCategory::Logical, 1>(
+      std::vector<int>{4}, std::vector<bool>{false, false, true, true})};
+  EXPECT_TRUE(RTNAME(DotProductLogical)(
+      *logicalVector1, *logicalVector1, __FILE__, __LINE__));
+  auto logicalVector2{MakeArray<TypeCategory::Logical, 1>(
+      std::vector<int>{4}, std::vector<bool>{true, true, false, false})};
+  EXPECT_TRUE(RTNAME(DotProductLogical)(
+      *logicalVector2, *logicalVector2, __FILE__, __LINE__));
+  EXPECT_FALSE(RTNAME(DotProductLogical)(
+      *logicalVector1, *logicalVector2, __FILE__, __LINE__));
+  EXPECT_FALSE(RTNAME(DotProductLogical)(
+      *logicalVector2, *logicalVector1, __FILE__, __LINE__));
+}


        


More information about the flang-commits mailing list