[flang-commits] [flang] a5a493e - [flang] Speed common runtime cases of DOT_PRODUCT & MATMUL
peter klausler via flang-commits
flang-commits at lists.llvm.org
Fri Oct 22 14:41:00 PDT 2021
Author: Peter Klausler
Date: 2021-10-22T14:36:13-07:00
New Revision: a5a493e1920572ca78b77fcdaef68c26e96d25e7
URL: https://github.com/llvm/llvm-project/commit/a5a493e1920572ca78b77fcdaef68c26e96d25e7
DIFF: https://github.com/llvm/llvm-project/commit/a5a493e1920572ca78b77fcdaef68c26e96d25e7.diff
LOG: [flang] Speed common runtime cases of DOT_PRODUCT & MATMUL
Look for contiguous numeric argument arrays at runtime and
use specialized code for them.
Differential Revision: https://reviews.llvm.org/D112239
Added:
Modified:
flang/include/flang/Runtime/c-or-cpp.h
flang/include/flang/Runtime/descriptor.h
flang/runtime/dot-product.cpp
flang/runtime/matmul.cpp
flang/runtime/tools.h
Removed:
################################################################################
diff --git a/flang/include/flang/Runtime/c-or-cpp.h b/flang/include/flang/Runtime/c-or-cpp.h
index 4babd885cad32..8bac523907750 100644
--- a/flang/include/flang/Runtime/c-or-cpp.h
+++ b/flang/include/flang/Runtime/c-or-cpp.h
@@ -13,11 +13,13 @@
#define IF_CPLUSPLUS(x) x
#define IF_NOT_CPLUSPLUS(x)
#define DEFAULT_VALUE(x) = (x)
+#define RESTRICT __restrict
#else
#include <stdbool.h>
#define IF_CPLUSPLUS(x)
#define IF_NOT_CPLUSPLUS(x) x
#define DEFAULT_VALUE(x)
+#define RESTRICT restrict
#endif
#define FORTRAN_EXTERN_C_BEGIN IF_CPLUSPLUS(extern "C" {)
diff --git a/flang/include/flang/Runtime/descriptor.h b/flang/include/flang/Runtime/descriptor.h
index 2b927df3bcd29..75c5e2176d929 100644
--- a/flang/include/flang/Runtime/descriptor.h
+++ b/flang/include/flang/Runtime/descriptor.h
@@ -304,7 +304,10 @@ class Descriptor {
bool IsContiguous(int leadingDimensions = maxRank) const {
auto bytes{static_cast<SubscriptValue>(ElementBytes())};
- for (int j{0}; j < leadingDimensions && j < raw_.rank; ++j) {
+ if (leadingDimensions > raw_.rank) {
+ leadingDimensions = raw_.rank;
+ }
+ for (int j{0}; j < leadingDimensions; ++j) {
const Dimension &dim{GetDimension(j)};
if (bytes != dim.ByteStride()) {
return false;
diff --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp
index db790c392c11b..4b8029768b950 100644
--- a/flang/runtime/dot-product.cpp
+++ b/flang/runtime/dot-product.cpp
@@ -15,21 +15,29 @@
namespace Fortran::runtime {
-template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
+// Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first
+// argument; MATMUL does not.
+
+// General accumulator for any type and stride; this is not used for
+// contiguous numeric vectors.
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
class Accumulator {
public:
- using Result = RESULT;
+ using Result = AccumulationType<RCAT, RKIND>;
Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
- void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
- if constexpr (XCAT == TypeCategory::Complex) {
- sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
- static_cast<Result>(*y_.Element<YT>(&yAt));
- } else if constexpr (XCAT == TypeCategory::Logical) {
+ void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) {
+ if constexpr (RCAT == TypeCategory::Logical) {
sum_ = sum_ ||
(IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
} else {
- sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
- static_cast<Result>(*y_.Element<YT>(&yAt));
+ const XT &xElement{*x_.Element<XT>(&xAt)};
+ const YT &yElement{*y_.Element<YT>(&yAt)};
+ if constexpr (RCAT == TypeCategory::Complex) {
+ sum_ += std::conj(static_cast<Result>(xElement)) *
+ static_cast<Result>(yElement);
+ } else {
+ sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement);
+ }
}
}
Result GetResult() const { return sum_; }
@@ -39,9 +47,10 @@ class Accumulator {
Result sum_{};
};
-template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
-static inline RESULT DoDotProduct(
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+static inline CppTypeFor<RCAT, RKIND> DoDotProduct(
const Descriptor &x, const Descriptor &y, Terminator &terminator) {
+ using Result = CppTypeFor<RCAT, RKIND>;
RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
SubscriptValue n{x.GetDimension(0).Extent()};
if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
@@ -49,24 +58,48 @@ static inline RESULT DoDotProduct(
"DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
}
- if constexpr (std::is_same_v<XT, YT>) {
- if constexpr (std::is_same_v<XT, float>) {
- // TODO: call BLAS-1 SDOT or SDSDOT
- } else if constexpr (std::is_same_v<XT, double>) {
- // TODO: call BLAS-1 DDOT
- } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
- // TODO: call BLAS-1 CDOTC
- } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
- // TODO: call BLAS-1 ZDOTC
+ if constexpr (RCAT != TypeCategory::Logical) {
+ if (x.GetDimension(0).ByteStride() == sizeof(XT) &&
+ y.GetDimension(0).ByteStride() == sizeof(YT)) {
+ // Contiguous numeric vectors
+ if constexpr (std::is_same_v<XT, YT>) {
+ // Contiguous homogeneous numeric vectors
+ if constexpr (std::is_same_v<XT, float>) {
+ // TODO: call BLAS-1 SDOT or SDSDOT
+ } else if constexpr (std::is_same_v<XT, double>) {
+ // TODO: call BLAS-1 DDOT
+ } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+ // TODO: call BLAS-1 CDOTC
+ } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+ // TODO: call BLAS-1 ZDOTC
+ }
+ }
+ XT *xp{x.OffsetElement<XT>(0)};
+ YT *yp{y.OffsetElement<YT>(0)};
+ using AccumType = AccumulationType<RCAT, RKIND>;
+ AccumType accum{};
+ if constexpr (RCAT == TypeCategory::Complex) {
+ for (SubscriptValue j{0}; j < n; ++j) {
+ accum += std::conj(static_cast<AccumType>(*xp++)) *
+ static_cast<AccumType>(*yp++);
+ }
+ } else {
+ for (SubscriptValue j{0}; j < n; ++j) {
+ accum +=
+ static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++);
+ }
+ }
+ return static_cast<Result>(accum);
}
}
+ // Non-contiguous, heterogeneous, & LOGICAL cases
SubscriptValue xAt{x.GetDimension(0).LowerBound()};
SubscriptValue yAt{y.GetDimension(0).LowerBound()};
- Accumulator<RESULT, XCAT, XT, YT> accumulator{x, y};
+ Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
for (SubscriptValue j{0}; j < n; ++j) {
- accumulator.Accumulate(xAt++, yAt++);
+ accumulator.AccumulateIndexed(xAt++, yAt++);
}
- return accumulator.GetResult();
+ return static_cast<Result>(accumulator.GetResult());
}
template <TypeCategory RCAT, int RKIND> struct DotProduct {
@@ -79,7 +112,7 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
if constexpr (resultType->first == RCAT &&
resultType->second <= RKIND) {
- return DoDotProduct<Result, XCAT, CppTypeFor<XCAT, XKIND>,
+ return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>,
CppTypeFor<YCAT, YKIND>>(x, y, terminator);
}
}
@@ -97,26 +130,32 @@ template <TypeCategory RCAT, int RKIND> struct DotProduct {
Result operator()(const Descriptor &x, const Descriptor &y,
const char *source, int line) const {
Terminator terminator{source, line};
- auto xCatKind{x.type().GetCategoryAndKind()};
- auto yCatKind{y.type().GetCategoryAndKind()};
- RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
- return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, terminator,
- x, y, terminator, yCatKind->first, yCatKind->second);
+ if (RCAT != TypeCategory::Logical && x.type() == y.type()) {
+ // No conversions needed, operands and result have same known type
+ return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}(
+ x, y, terminator);
+ } else {
+ auto xCatKind{x.type().GetCategoryAndKind()};
+ auto yCatKind{y.type().GetCategoryAndKind()};
+ RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+ return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second,
+ terminator, x, y, terminator, yCatKind->first, yCatKind->second);
+ }
}
};
extern "C" {
std::int8_t RTNAME(DotProductInteger1)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
+ return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line);
}
std::int16_t RTNAME(DotProductInteger2)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
+ return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line);
}
std::int32_t RTNAME(DotProductInteger4)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
+ return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line);
}
std::int64_t RTNAME(DotProductInteger8)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
@@ -130,9 +169,10 @@ common::int128_t RTNAME(DotProductInteger16)(
#endif
// TODO: REAL/COMPLEX(2 & 3)
+// Intermediate results and operations are at least 64 bits
float RTNAME(DotProductReal4)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
+ return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line);
}
double RTNAME(DotProductReal8)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
@@ -152,7 +192,7 @@ long double RTNAME(DotProductReal16)(
void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- auto z{DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line)};
+ auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)};
result = std::complex<float>{
static_cast<float>(z.real()), static_cast<float>(z.imag())};
}
diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
index ec1581456fcb9..2d0459c0f35a9 100644
--- a/flang/runtime/matmul.cpp
+++ b/flang/runtime/matmul.cpp
@@ -22,19 +22,19 @@
#include "flang/Runtime/matmul.h"
#include "terminator.h"
#include "tools.h"
+#include "flang/Runtime/c-or-cpp.h"
#include "flang/Runtime/cpp-type.h"
#include "flang/Runtime/descriptor.h"
+#include <cstring>
namespace Fortran::runtime {
+// General accumulator for any type and stride; this is not used for
+// contiguous numeric cases.
template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
class Accumulator {
public:
- // Accumulate floating-point results in (at least) double precision
- using Result = CppTypeFor<RCAT,
- RCAT == TypeCategory::Real || RCAT == TypeCategory::Complex
- ? std::max(RKIND, static_cast<int>(sizeof(double)))
- : RKIND>;
+ using Result = AccumulationType<RCAT, RKIND>;
Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) {
if constexpr (RCAT == TypeCategory::Logical) {
@@ -52,6 +52,103 @@ class Accumulator {
Result sum_{};
};
+// Contiguous numeric matrix*matrix multiplication
+// matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols)
+// All three arrays are addressed as dense column-major storage (0-based):
+// X(I,K) is x[I + K*rows], Y(K,J) is y[K + J*n], RES(I,J) is product[I + J*rows].
+// Straightforward algorithm:
+// DO 1 I = 1, NROWS
+// DO 1 J = 1, NCOLS
+// RES(I,J) = 0
+// DO 1 K = 1, N
+// 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J)
+// With loop distribution and transposition to avoid the inner sum
+// reduction and to avoid non-unit strides:
+// DO 1 I = 1, NROWS
+// DO 1 J = 1, NCOLS
+// 1 RES(I,J) = 0
+// DO 2 K = 1, N
+// DO 2 J = 1, NCOLS
+// DO 2 I = 1, NROWS
+// 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
+// Products and running sums are formed directly in ResultType
+// (CppTypeFor<RCAT, RKIND>), not in AccumulationType<RCAT, RKIND> as the
+// general Accumulator does, so e.g. REAL(4) results are summed in float here.
+// RESTRICT promises that product, x, and y do not overlap; the caller must
+// guarantee this (NOTE(review): an aliased result would be undefined).
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
+ SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x,
+ const YT *RESTRICT y, SubscriptValue n) {
+ using ResultType = CppTypeFor<RCAT, RKIND>;
+ // RES = 0; relies on all-bits-zero being the value zero of ResultType.
+ std::memset(product, 0, rows * cols * sizeof *product);
+ // xp0 points at X(1,K), the start of the current column K of X.
+ const XT *RESTRICT xp0{x};
+ for (SubscriptValue k{0}; k < n; ++k) {
+ // p walks all of RES in storage order across the J and I loops below.
+ ResultType *RESTRICT p{product};
+ for (SubscriptValue j{0}; j < cols; ++j) {
+ // Restart at X(1,K) for every column J of the result.
+ const XT *RESTRICT xp{xp0};
+ // Y(K,J) is invariant in the innermost loop, so convert it once.
+ auto yv{static_cast<ResultType>(y[k + j * n])};
+ for (SubscriptValue i{0}; i < rows; ++i) {
+ *p++ += static_cast<ResultType>(*xp++) * yv;
+ }
+ }
+ // Advance to X(1,K+1).
+ xp0 += rows;
+ }
+}
+
+// Contiguous numeric matrix*vector multiplication
+// matrix(rows,n) * column vector(n) -> column vector(rows)
+// X(J,K) is x[J + K*rows] (dense, column-major), Y(K) is y[K], and RES(J)
+// is product[J], all 0-based.
+// Straightforward algorithm:
+// DO 1 J = 1, NROWS
+// RES(J) = 0
+// DO 1 K = 1, N
+// 1 RES(J) = RES(J) + X(J,K)*Y(K)
+// With loop distribution and transposition to avoid the inner
+// sum reduction and to avoid non-unit strides:
+// DO 1 J = 1, NROWS
+// 1 RES(J) = 0
+// DO 2 K = 1, N
+// DO 2 J = 1, NROWS
+// 2 RES(J) = RES(J) + X(J,K)*Y(K)
+// Sums are formed directly in ResultType (CppTypeFor<RCAT, RKIND>) rather
+// than in AccumulationType<RCAT, RKIND>; product must not overlap x or y
+// (RESTRICT).
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT product,
+ SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x,
+ const YT *RESTRICT y) {
+ using ResultType = CppTypeFor<RCAT, RKIND>;
+ // RES = 0; relies on all-bits-zero being the value zero of ResultType.
+ std::memset(product, 0, rows * sizeof *product);
+ for (SubscriptValue k{0}; k < n; ++k) {
+ ResultType *RESTRICT p{product};
+ // Y(K) is invariant in the inner loop; y advances one element per K.
+ auto yv{static_cast<ResultType>(*y++)};
+ // x is never reset: successive K iterations consume successive
+ // columns of X, so x steps through X in storage order.
+ for (SubscriptValue j{0}; j < rows; ++j) {
+ *p++ += static_cast<ResultType>(*x++) * yv;
+ }
+ }
+}
+
+// Contiguous numeric vector*matrix multiplication
+// row vector(n) * matrix(n,cols) -> row vector(cols)
+// X(K) is x[K], Y(K,J) is y[K + J*n] (dense, column-major), and RES(J) is
+// product[J], all 0-based. Note the extent order in the signature:
+// n (the length of x) comes before cols.
+// Straightforward algorithm:
+// DO 1 J = 1, NCOLS
+// RES(J) = 0
+// DO 1 K = 1, N
+// 1 RES(J) = RES(J) + X(K)*Y(K,J)
+// With loop distribution and transposition to avoid the inner
+// sum reduction and one non-unit stride (the other remains):
+// DO 1 J = 1, NCOLS
+// 1 RES(J) = 0
+// DO 2 K = 1, N
+// DO 2 J = 1, NCOLS
+// 2 RES(J) = RES(J) + X(K)*Y(K,J)
+// Sums are formed directly in ResultType (CppTypeFor<RCAT, RKIND>) rather
+// than in AccumulationType<RCAT, RKIND>; product must not overlap x or y
+// (RESTRICT).
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
+ SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x,
+ const YT *RESTRICT y) {
+ using ResultType = CppTypeFor<RCAT, RKIND>;
+ // RES = 0; relies on all-bits-zero being the value zero of ResultType.
+ std::memset(product, 0, cols * sizeof *product);
+ for (SubscriptValue k{0}; k < n; ++k) {
+ ResultType *RESTRICT p{product};
+ // X(K) is invariant in the inner loop.
+ auto xv{static_cast<ResultType>(*x++)};
+ // yp starts at Y(K,1) and steps n elements to reach Y(K,J+1): this is
+ // the non-unit stride that the transposition could not remove.
+ const YT *RESTRICT yp{&y[k]};
+ for (SubscriptValue j{0}; j < cols; ++j) {
+ *p++ += xv * static_cast<ResultType>(*yp);
+ yp += n;
+ }
+ }
+}
+
// Implements an instance of MATMUL for given argument types.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
typename YT>
@@ -79,36 +176,82 @@ static inline void DoMatmul(
}
} else {
RUNTIME_CHECK(terminator, resRank == result.rank());
- RUNTIME_CHECK(terminator, result.type() == (TypeCode{RCAT, RKIND}));
+ RUNTIME_CHECK(
+ terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND));
RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
RUNTIME_CHECK(terminator,
resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
}
- using WriteResult =
- CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
- RKIND>;
SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
if (n != y.GetDimension(0).Extent()) {
terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)",
static_cast<std::intmax_t>(n),
static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
}
+ using WriteResult =
+ CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
+ RKIND>;
+ if constexpr (RCAT != TypeCategory::Logical) {
+ if (x.IsContiguous() && y.IsContiguous() &&
+ (IS_ALLOCATING || result.IsContiguous())) {
+ // Contiguous numeric matrices
+ if (resRank == 2) { // M*M -> M
+ if (std::is_same_v<XT, YT>) {
+ if constexpr (std::is_same_v<XT, float>) {
+ // TODO: call BLAS-3 SGEMM
+ } else if constexpr (std::is_same_v<XT, double>) {
+ // TODO: call BLAS-3 DGEMM
+ } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+ // TODO: call BLAS-3 CGEMM
+ } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+ // TODO: call BLAS-3 ZGEMM
+ }
+ }
+ MatrixTimesMatrix<RCAT, RKIND, XT, YT>(
+ result.template OffsetElement<WriteResult>(), extent[0], extent[1],
+ x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
+ return;
+ } else if (xRank == 2) { // M*V -> V
+ if (std::is_same_v<XT, YT>) {
+ if constexpr (std::is_same_v<XT, float>) {
+ // TODO: call BLAS-2 SGEMV(x,y)
+ } else if constexpr (std::is_same_v<XT, double>) {
+ // TODO: call BLAS-2 DGEMV(x,y)
+ } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+ // TODO: call BLAS-2 CGEMV(x,y)
+ } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+ // TODO: call BLAS-2 ZGEMV(x,y)
+ }
+ }
+ MatrixTimesVector<RCAT, RKIND, XT, YT>(
+ result.template OffsetElement<WriteResult>(), extent[0], n,
+ x.OffsetElement<XT>(), y.OffsetElement<YT>());
+ return;
+ } else { // V*M -> V
+ if (std::is_same_v<XT, YT>) {
+ if constexpr (std::is_same_v<XT, float>) {
+ // TODO: call BLAS-2 SGEMV(y,x)
+ } else if constexpr (std::is_same_v<XT, double>) {
+ // TODO: call BLAS-2 DGEMV(y,x)
+ } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+ // TODO: call BLAS-2 CGEMV(y,x)
+ } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+ // TODO: call BLAS-2 ZGEMV(y,x)
+ }
+ }
+ VectorTimesMatrix<RCAT, RKIND, XT, YT>(
+ result.template OffsetElement<WriteResult>(), n, extent[0],
+ x.OffsetElement<XT>(), y.OffsetElement<YT>());
+ return;
+ }
+ }
+ }
+ // General algorithms for LOGICAL and noncontiguity
SubscriptValue xAt[2], yAt[2], resAt[2];
x.GetLowerBounds(xAt);
y.GetLowerBounds(yAt);
result.GetLowerBounds(resAt);
if (resRank == 2) { // M*M -> M
- if constexpr (std::is_same_v<XT, YT>) {
- if constexpr (std::is_same_v<XT, float>) {
- // TODO: call BLAS-3 SGEMM
- } else if constexpr (std::is_same_v<XT, double>) {
- // TODO: call BLAS-3 DGEMM
- } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
- // TODO: call BLAS-3 CGEMM
- } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
- // TODO: call BLAS-3 ZGEMM
- }
- }
SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
for (SubscriptValue i{0}; i < extent[0]; ++i) {
for (SubscriptValue j{0}; j < extent[1]; ++j) {
@@ -125,44 +268,31 @@ static inline void DoMatmul(
++resAt[0];
++xAt[0];
}
- } else {
- if constexpr (std::is_same_v<XT, YT>) {
- if constexpr (std::is_same_v<XT, float>) {
- // TODO: call BLAS-2 SGEMV
- } else if constexpr (std::is_same_v<XT, double>) {
- // TODO: call BLAS-2 DGEMV
- } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
- // TODO: call BLAS-2 CGEMV
- } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
- // TODO: call BLAS-2 ZGEMV
+ } else if (xRank == 2) { // M*V -> V
+ SubscriptValue x1{xAt[1]}, y0{yAt[0]};
+ for (SubscriptValue j{0}; j < extent[0]; ++j) {
+ Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+ for (SubscriptValue k{0}; k < n; ++k) {
+ xAt[1] = x1 + k;
+ yAt[0] = y0 + k;
+ accumulator.Accumulate(xAt, yAt);
}
+ *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+ ++resAt[0];
+ ++xAt[0];
}
- if (xRank == 2) { // M*V -> V
- SubscriptValue x1{xAt[1]}, y0{yAt[0]};
- for (SubscriptValue j{0}; j < extent[0]; ++j) {
- Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
- for (SubscriptValue k{0}; k < n; ++k) {
- xAt[1] = x1 + k;
- yAt[0] = y0 + k;
- accumulator.Accumulate(xAt, yAt);
- }
- *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
- ++resAt[0];
- ++xAt[0];
- }
- } else { // V*M -> V
- SubscriptValue x0{xAt[0]}, y0{yAt[0]};
- for (SubscriptValue j{0}; j < extent[0]; ++j) {
- Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
- for (SubscriptValue k{0}; k < n; ++k) {
- xAt[0] = x0 + k;
- yAt[0] = y0 + k;
- accumulator.Accumulate(xAt, yAt);
- }
- *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
- ++resAt[0];
- ++yAt[1];
+ } else { // V*M -> V
+ SubscriptValue x0{xAt[0]}, y0{yAt[0]};
+ for (SubscriptValue j{0}; j < extent[0]; ++j) {
+ Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+ for (SubscriptValue k{0}; k < n; ++k) {
+ xAt[0] = x0 + k;
+ yAt[0] = y0 + k;
+ accumulator.Accumulate(xAt, yAt);
}
+ *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+ ++resAt[0];
+ ++yAt[1];
}
}
}
diff --git a/flang/runtime/tools.h b/flang/runtime/tools.h
index ee2641b305b05..3e0a68b180172 100644
--- a/flang/runtime/tools.h
+++ b/flang/runtime/tools.h
@@ -334,5 +334,12 @@ std::optional<std::pair<TypeCategory, int>> inline constexpr GetResultType(
return std::nullopt;
}
+// Accumulate floating-point results in (at least) double precision
+// AccumulationType<CAT, KIND> is the C++ type used for intermediate sums
+// when computing a result of category CAT and kind KIND. For REAL and
+// COMPLEX the kind is raised to at least sizeof(double) (8 on common
+// targets, where REAL(4) therefore accumulates in CppTypeFor<Real, 8>);
+// a larger KIND such as 10 or 16 is kept unchanged. Every other category
+// (e.g. INTEGER, LOGICAL) uses KIND as given.
+template <TypeCategory CAT, int KIND>
+using AccumulationType = CppTypeFor<CAT,
+ CAT == TypeCategory::Real || CAT == TypeCategory::Complex
+ ? std::max(KIND, static_cast<int>(sizeof(double)))
+ : KIND>;
+
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_TOOLS_H_
More information about the flang-commits mailing list