[flang-commits] [flang] a5a493e - [flang] Speed common runtime cases of DOT_PRODUCT & MATMUL

peter klausler via flang-commits flang-commits at lists.llvm.org
Fri Oct 22 14:41:00 PDT 2021


Author: Peter Klausler
Date: 2021-10-22T14:36:13-07:00
New Revision: a5a493e1920572ca78b77fcdaef68c26e96d25e7

URL: https://github.com/llvm/llvm-project/commit/a5a493e1920572ca78b77fcdaef68c26e96d25e7
DIFF: https://github.com/llvm/llvm-project/commit/a5a493e1920572ca78b77fcdaef68c26e96d25e7.diff

LOG: [flang] Speed common runtime cases of DOT_PRODUCT & MATMUL

Look for contiguous numeric argument arrays at runtime and
use specialized code for them.

Differential Revision: https://reviews.llvm.org/D112239

Added: 
    

Modified: 
    flang/include/flang/Runtime/c-or-cpp.h
    flang/include/flang/Runtime/descriptor.h
    flang/runtime/dot-product.cpp
    flang/runtime/matmul.cpp
    flang/runtime/tools.h

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Runtime/c-or-cpp.h b/flang/include/flang/Runtime/c-or-cpp.h
index 4babd885cad32..8bac523907750 100644
--- a/flang/include/flang/Runtime/c-or-cpp.h
+++ b/flang/include/flang/Runtime/c-or-cpp.h
@@ -13,11 +13,13 @@
 #define IF_CPLUSPLUS(x) x
 #define IF_NOT_CPLUSPLUS(x)
 #define DEFAULT_VALUE(x) = (x)
+#define RESTRICT __restrict
 #else
 #include <stdbool.h>
 #define IF_CPLUSPLUS(x)
 #define IF_NOT_CPLUSPLUS(x) x
 #define DEFAULT_VALUE(x)
+#define RESTRICT restrict
 #endif
 
 #define FORTRAN_EXTERN_C_BEGIN IF_CPLUSPLUS(extern "C" {)

diff  --git a/flang/include/flang/Runtime/descriptor.h b/flang/include/flang/Runtime/descriptor.h
index 2b927df3bcd29..75c5e2176d929 100644
--- a/flang/include/flang/Runtime/descriptor.h
+++ b/flang/include/flang/Runtime/descriptor.h
@@ -304,7 +304,10 @@ class Descriptor {
 
   bool IsContiguous(int leadingDimensions = maxRank) const {
     auto bytes{static_cast<SubscriptValue>(ElementBytes())};
-    for (int j{0}; j < leadingDimensions && j < raw_.rank; ++j) {
+    if (leadingDimensions > raw_.rank) {
+      leadingDimensions = raw_.rank;
+    }
+    for (int j{0}; j < leadingDimensions; ++j) {
       const Dimension &dim{GetDimension(j)};
       if (bytes != dim.ByteStride()) {
         return false;

diff  --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp
index db790c392c11b..4b8029768b950 100644
--- a/flang/runtime/dot-product.cpp
+++ b/flang/runtime/dot-product.cpp
@@ -15,21 +15,29 @@
 
 namespace Fortran::runtime {
 
-template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
+// Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
+// argument; MATMUL does not.
+
+// General accumulator for any type and stride; this is not used for
+// contiguous numeric vectors.
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
 class Accumulator {
 public:
-  using Result = RESULT;
+  using Result = AccumulationType<RCAT, RKIND>;
   Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
-  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
-    if constexpr (XCAT == TypeCategory::Complex) {
-      sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
-          static_cast<Result>(*y_.Element<YT>(&yAt));
-    } else if constexpr (XCAT == TypeCategory::Logical) {
+  void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
+    if constexpr (RCAT == TypeCategory::Logical) {
       sum_ = sum_ ||
           (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
     } else {
-      sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
-          static_cast<Result>(*y_.Element<YT>(&yAt));
+      const XT &xElement{*x_.Element<XT>(&xAt)};
+      const YT &yElement{*y_.Element<YT>(&yAt)};
+      if constexpr (RCAT == TypeCategory::Complex) {
+        sum_ += std::conj(static_cast<Result>(xElement)) *
+            static_cast<Result>(yElement);
+      } else {
+        sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
+      }
     }
   }
   Result GetResult() const { return sum_; }
@@ -39,9 +47,10 @@ class Accumulator {
   Result sum_{};
 };
 
-template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
-static inline RESULT DoDotProduct(
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+static inline CppTypeFor<RCAT, RKIND> DoDotProduct(
     const Descriptor &x, const Descriptor &y, Terminator &terminator) {
+  using Result = CppTypeFor<RCAT, RKIND>;
   RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
   SubscriptValue n{x.GetDimension(0).Extent()};
   if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
@@ -49,24 +58,48 @@ static inline RESULT DoDotProduct(
         "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
         static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
   }
-  if constexpr (std::is_same_v<XT, YT>) {
-    if constexpr (std::is_same_v<XT, float>) {
-      // TODO: call BLAS-1 SDOT or SDSDOT
-    } else if constexpr (std::is_same_v<XT, double>) {
-      // TODO: call BLAS-1 DDOT
-    } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
-      // TODO: call BLAS-1 CDOTC
-    } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
-      // TODO: call BLAS-1 ZDOTC
+  if constexpr (RCAT != TypeCategory::Logical) {
+    if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
+        y.GetDimension(0).ByteStride() == sizeof(YT)) {
+      // Contiguous numeric vectors
+      if constexpr (std::is_same_v<XT, YT>) {
+        // Contiguous homogeneous numeric vectors
+        if constexpr (std::is_same_v<XT, float>) {
+          // TODO: call BLAS-1 SDOT or SDSDOT
+        } else if constexpr (std::is_same_v<XT, double>) {
+          // TODO: call BLAS-1 DDOT
+        } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+          // TODO: call BLAS-1 CDOTC
+        } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+          // TODO: call BLAS-1 ZDOTC
+        }
+      }
+      XT *xp{x.OffsetElement<XT>(0)};
+      YT *yp{y.OffsetElement<YT>(0)};
+      using AccumType = AccumulationType<RCAT, RKIND>;
+      AccumType accum{};
+      if constexpr (RCAT == TypeCategory::Complex) {
+        for (SubscriptValue j{0}; j < n; ++j) {
+          accum += std::conj(static_cast<AccumType>(*xp++)) *
+              static_cast<AccumType>(*yp++);
+        }
+      } else {
+        for (SubscriptValue j{0}; j < n; ++j) {
+          accum +=
+              static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
+        }
+      }
+      return static_cast<Result>(accum);
     }
   }
+  // Non-contiguous, heterogeneous, & LOGICAL cases
   SubscriptValue xAt{x.GetDimension(0).LowerBound()};
   SubscriptValue yAt{y.GetDimension(0).LowerBound()};
-  Accumulator<RESULT, XCAT, XT, YT> accumulator{x, y};
+  Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
   for (SubscriptValue j{0}; j < n; ++j) {
-    accumulator.Accumulate(xAt++, yAt++);
+    accumulator.AccumulateIndexed(xAt++, yAt++);
   }
-  return accumulator.GetResult();
+  return static_cast<Result>(accumulator.GetResult());
 }
 
 template <TypeCategory RCAT, int RKIND> struct DotProduct {
@@ -79,7 +112,7 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
           if constexpr (resultType->first == RCAT &&
               resultType->second <= RKIND) {
-            return DoDotProduct<Result, XCAT, CppTypeFor<XCAT, XKIND>,
+            return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
                 CppTypeFor<YCAT, YKIND>>(x, y, terminator);
           }
         }
@@ -97,26 +130,32 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
   Result operator()(const Descriptor &x, const Descriptor &y,
       const char *source, int line) const {
     Terminator terminator{source, line};
-    auto xCatKind{x.type().GetCategoryAndKind()};
-    auto yCatKind{y.type().GetCategoryAndKind()};
-    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
-    return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, terminator,
-        x, y, terminator, yCatKind->first, yCatKind->second);
+    if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
+      // No conversions needed, operands and result have same known type
+      return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
+          x, y, terminator);
+    } else {
+      auto xCatKind{x.type().GetCategoryAndKind()};
+      auto yCatKind{y.type().GetCategoryAndKind()};
+      RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+      return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
+          terminator, x, y, terminator, yCatKind->first, yCatKind->second);
+    }
   }
 };
 
 extern "C" {
 std::int8_t RTNAME(DotProductInteger1)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
 }
 std::int16_t RTNAME(DotProductInteger2)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
 }
 std::int32_t RTNAME(DotProductInteger4)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
 }
 std::int64_t RTNAME(DotProductInteger8)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
@@ -130,9 +169,10 @@ common::int128_t RTNAME(DotProductInteger16)(
 #endif
 
 // TODO: REAL/COMPLEX(2 & 3)
+// Intermediate results and operations are at least 64 bits
 float RTNAME(DotProductReal4)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
+  return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
 }
 double RTNAME(DotProductReal8)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
@@ -152,7 +192,7 @@ long double RTNAME(DotProductReal16)(
 
 void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  auto z{DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line)};
+  auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)};
   result = std::complex<float>{
       static_cast<float>(z.real()), static_cast<float>(z.imag())};
 }

diff  --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
index ec1581456fcb9..2d0459c0f35a9 100644
--- a/flang/runtime/matmul.cpp
+++ b/flang/runtime/matmul.cpp
@@ -22,19 +22,19 @@
 #include "flang/Runtime/matmul.h"
 #include "terminator.h"
 #include "tools.h"
+#include "flang/Runtime/c-or-cpp.h"
 #include "flang/Runtime/cpp-type.h"
 #include "flang/Runtime/descriptor.h"
+#include <cstring>
 
 namespace Fortran::runtime {
 
+// General accumulator for any type and stride; this is not used for
+// contiguous numeric cases.
 template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
 class Accumulator {
 public:
-  // Accumulate floating-point results in (at least) double precision
-  using Result = CppTypeFor<RCAT,
-      RCAT == TypeCategory::Real || RCAT == TypeCategory::Complex
-          ? std::max(RKIND, static_cast<int>(sizeof(double)))
-          : RKIND>;
+  using Result = AccumulationType<RCAT, RKIND>;
   Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
   void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) {
     if constexpr (RCAT == TypeCategory::Logical) {
@@ -52,6 +52,103 @@ class Accumulator {
   Result sum_{};
 };
 
+// Contiguous numeric matrix*matrix multiplication
+//   matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols)
+// Straightforward algorithm:
+//   DO 1 I = 1, NROWS
+//    DO 1 J = 1, NCOLS
+//     RES(I,J) = 0
+//     DO 1 K = 1, N
+//   1  RES(I,J) = RES(I,J) + X(I,K)*Y(K,J)
+// With loop distribution and transposition to avoid the inner sum
+// reduction and to avoid non-unit strides:
+//   DO 1 I = 1, NROWS
+//    DO 1 J = 1, NCOLS
+//   1 RES(I,J) = 0
+//   DO 2 K = 1, N
+//    DO 2 J = 1, NCOLS
+//     DO 2 I = 1, NROWS
+//   2  RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
+    SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x,
+    const YT *RESTRICT y, SubscriptValue n) {
+  using ResultType = CppTypeFor<RCAT, RKIND>;
+  std::memset(product, 0, rows * cols * sizeof *product);
+  const XT *RESTRICT xp0{x};
+  for (SubscriptValue k{0}; k < n; ++k) {
+    ResultType *RESTRICT p{product};
+    for (SubscriptValue j{0}; j < cols; ++j) {
+      const XT *RESTRICT xp{xp0};
+      auto yv{static_cast<ResultType>(y[k + j * n])};
+      for (SubscriptValue i{0}; i < rows; ++i) {
+        *p++ += static_cast<ResultType>(*xp++) * yv;
+      }
+    }
+    xp0 += rows;
+  }
+}
+
+// Contiguous numeric matrix*vector multiplication
+//   matrix(rows,n) * column vector(n) -> column vector(rows)
+// Straightforward algorithm:
+//   DO 1 J = 1, NROWS
+//    RES(J) = 0
+//    DO 1 K = 1, N
+//   1 RES(J) = RES(J) + X(J,K)*Y(K)
+// With loop distribution and transposition to avoid the inner
+// sum reduction and to avoid non-unit strides:
+//   DO 1 J = 1, NROWS
+//   1 RES(J) = 0
+//   DO 2 K = 1, N
+//    DO 2 J = 1, NROWS
+//   2 RES(J) = RES(J) + X(J,K)*Y(K)
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT product,
+    SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x,
+    const YT *RESTRICT y) {
+  using ResultType = CppTypeFor<RCAT, RKIND>;
+  std::memset(product, 0, rows * sizeof *product);
+  for (SubscriptValue k{0}; k < n; ++k) {
+    ResultType *RESTRICT p{product};
+    auto yv{static_cast<ResultType>(*y++)};
+    for (SubscriptValue j{0}; j < rows; ++j) {
+      *p++ += static_cast<ResultType>(*x++) * yv;
+    }
+  }
+}
+
+// Contiguous numeric vector*matrix multiplication
+//   row vector(n) * matrix(n,cols) -> row vector(cols)
+// Straightforward algorithm:
+//   DO 1 J = 1, NCOLS
+//    RES(J) = 0
+//    DO 1 K = 1, N
+//   1 RES(J) = RES(J) + X(K)*Y(K,J)
+// With loop distribution and transposition to avoid the inner
+// sum reduction and one non-unit stride (the other remains):
+//   DO 1 J = 1, NCOLS
+//   1 RES(J) = 0
+//   DO 2 K = 1, N
+//    DO 2 J = 1, NCOLS
+//   2 RES(J) = RES(J) + X(K)*Y(K,J)
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
+    SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x,
+    const YT *RESTRICT y) {
+  using ResultType = CppTypeFor<RCAT, RKIND>;
+  std::memset(product, 0, cols * sizeof *product);
+  for (SubscriptValue k{0}; k < n; ++k) {
+    ResultType *RESTRICT p{product};
+    auto xv{static_cast<ResultType>(*x++)};
+    const YT *RESTRICT yp{&y[k]};
+    for (SubscriptValue j{0}; j < cols; ++j) {
+      *p++ += xv * static_cast<ResultType>(*yp);
+      yp += n;
+    }
+  }
+}
+
 // Implements an instance of MATMUL for given argument types.
 template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
     typename YT>
@@ -79,36 +176,82 @@ static inline void DoMatmul(
     }
   } else {
     RUNTIME_CHECK(terminator, resRank == result.rank());
-    RUNTIME_CHECK(terminator, result.type() == (TypeCode{RCAT, RKIND}));
+    RUNTIME_CHECK(terminator, result.ElementBytes() == // COMPLEX(k): 2*k bytes
+        static_cast<std::size_t>((RCAT == TypeCategory::Complex ? 2 : 1) * RKIND));
     RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
     RUNTIME_CHECK(terminator,
         resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
   }
-  using WriteResult =
-      CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
-          RKIND>;
   SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
   if (n != y.GetDimension(0).Extent()) {
     terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)",
         static_cast<std::intmax_t>(n),
         static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
   }
+  using WriteResult =
+      CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
+          RKIND>;
+  if constexpr (RCAT != TypeCategory::Logical) {
+    if (x.IsContiguous() && y.IsContiguous() &&
+        (IS_ALLOCATING || result.IsContiguous())) {
+      // Contiguous numeric matrices
+      if (resRank == 2) { // M*M -> M
+        if (std::is_same_v<XT, YT>) {
+          if constexpr (std::is_same_v<XT, float>) {
+            // TODO: call BLAS-3 SGEMM
+          } else if constexpr (std::is_same_v<XT, double>) {
+            // TODO: call BLAS-3 DGEMM
+          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+            // TODO: call BLAS-3 CGEMM
+          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+            // TODO: call BLAS-3 ZGEMM
+          }
+        }
+        MatrixTimesMatrix<RCAT, RKIND, XT, YT>(
+            result.template OffsetElement<WriteResult>(), extent[0], extent[1],
+            x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
+        return;
+      } else if (xRank == 2) { // M*V -> V
+        if (std::is_same_v<XT, YT>) {
+          if constexpr (std::is_same_v<XT, float>) {
+            // TODO: call BLAS-2 SGEMV(x,y)
+          } else if constexpr (std::is_same_v<XT, double>) {
+            // TODO: call BLAS-2 DGEMV(x,y)
+          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+            // TODO: call BLAS-2 CGEMV(x,y)
+          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+            // TODO: call BLAS-2 ZGEMV(x,y)
+          }
+        }
+        MatrixTimesVector<RCAT, RKIND, XT, YT>(
+            result.template OffsetElement<WriteResult>(), extent[0], n,
+            x.OffsetElement<XT>(), y.OffsetElement<YT>());
+        return;
+      } else { // V*M -> V
+        if (std::is_same_v<XT, YT>) {
+          if constexpr (std::is_same_v<XT, float>) {
+            // TODO: call BLAS-2 SGEMV(y,x)
+          } else if constexpr (std::is_same_v<XT, double>) {
+            // TODO: call BLAS-2 DGEMV(y,x)
+          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+            // TODO: call BLAS-2 CGEMV(y,x)
+          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+            // TODO: call BLAS-2 ZGEMV(y,x)
+          }
+        }
+        VectorTimesMatrix<RCAT, RKIND, XT, YT>(
+            result.template OffsetElement<WriteResult>(), n, extent[0],
+            x.OffsetElement<XT>(), y.OffsetElement<YT>());
+        return;
+      }
+    }
+  }
+  // General algorithms for LOGICAL and noncontiguity
   SubscriptValue xAt[2], yAt[2], resAt[2];
   x.GetLowerBounds(xAt);
   y.GetLowerBounds(yAt);
   result.GetLowerBounds(resAt);
   if (resRank == 2) { // M*M -> M
-    if constexpr (std::is_same_v<XT, YT>) {
-      if constexpr (std::is_same_v<XT, float>) {
-        // TODO: call BLAS-3 SGEMM
-      } else if constexpr (std::is_same_v<XT, double>) {
-        // TODO: call BLAS-3 DGEMM
-      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
-        // TODO: call BLAS-3 CGEMM
-      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
-        // TODO: call BLAS-3 ZGEMM
-      }
-    }
     SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
     for (SubscriptValue i{0}; i < extent[0]; ++i) {
       for (SubscriptValue j{0}; j < extent[1]; ++j) {
@@ -125,44 +268,31 @@ static inline void DoMatmul(
       ++resAt[0];
       ++xAt[0];
     }
-  } else {
-    if constexpr (std::is_same_v<XT, YT>) {
-      if constexpr (std::is_same_v<XT, float>) {
-        // TODO: call BLAS-2 SGEMV
-      } else if constexpr (std::is_same_v<XT, double>) {
-        // TODO: call BLAS-2 DGEMV
-      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
-        // TODO: call BLAS-2 CGEMV
-      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
-        // TODO: call BLAS-2 ZGEMV
+  } else if (xRank == 2) { // M*V -> V
+    SubscriptValue x1{xAt[1]}, y0{yAt[0]};
+    for (SubscriptValue j{0}; j < extent[0]; ++j) {
+      Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+      for (SubscriptValue k{0}; k < n; ++k) {
+        xAt[1] = x1 + k;
+        yAt[0] = y0 + k;
+        accumulator.Accumulate(xAt, yAt);
       }
+      *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+      ++resAt[0];
+      ++xAt[0];
     }
-    if (xRank == 2) { // M*V -> V
-      SubscriptValue x1{xAt[1]}, y0{yAt[0]};
-      for (SubscriptValue j{0}; j < extent[0]; ++j) {
-        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
-        for (SubscriptValue k{0}; k < n; ++k) {
-          xAt[1] = x1 + k;
-          yAt[0] = y0 + k;
-          accumulator.Accumulate(xAt, yAt);
-        }
-        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
-        ++resAt[0];
-        ++xAt[0];
-      }
-    } else { // V*M -> V
-      SubscriptValue x0{xAt[0]}, y0{yAt[0]};
-      for (SubscriptValue j{0}; j < extent[0]; ++j) {
-        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
-        for (SubscriptValue k{0}; k < n; ++k) {
-          xAt[0] = x0 + k;
-          yAt[0] = y0 + k;
-          accumulator.Accumulate(xAt, yAt);
-        }
-        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
-        ++resAt[0];
-        ++yAt[1];
+  } else { // V*M -> V
+    SubscriptValue x0{xAt[0]}, y0{yAt[0]};
+    for (SubscriptValue j{0}; j < extent[0]; ++j) {
+      Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+      for (SubscriptValue k{0}; k < n; ++k) {
+        xAt[0] = x0 + k;
+        yAt[0] = y0 + k;
+        accumulator.Accumulate(xAt, yAt);
       }
+      *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+      ++resAt[0];
+      ++yAt[1];
     }
   }
 }

diff  --git a/flang/runtime/tools.h b/flang/runtime/tools.h
index ee2641b305b05..3e0a68b180172 100644
--- a/flang/runtime/tools.h
+++ b/flang/runtime/tools.h
@@ -334,5 +334,12 @@ std::optional<std::pair<TypeCategory, int>> inline constexpr GetResultType(
   return std::nullopt;
 }
 
+// Accumulate floating-point results in (at least) double precision
+template <TypeCategory CAT, int KIND>
+using AccumulationType = CppTypeFor<CAT,
+    CAT == TypeCategory::Real || CAT == TypeCategory::Complex
+        ? std::max(KIND, static_cast<int>(sizeof(double)))
+        : KIND>;
+
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_TOOLS_H_


        


More information about the flang-commits mailing list