[flang-commits] [flang] 5e1421b - [flang] Implement MATMUL in the runtime
peter klausler via flang-commits
flang-commits at lists.llvm.org
Tue May 18 11:00:02 PDT 2021
Author: peter klausler
Date: 2021-05-18T10:59:52-07:00
New Revision: 5e1421b22f642a6b34690d0d724e691ba3984836
URL: https://github.com/llvm/llvm-project/commit/5e1421b22f642a6b34690d0d724e691ba3984836
DIFF: https://github.com/llvm/llvm-project/commit/5e1421b22f642a6b34690d0d724e691ba3984836.diff
LOG: [flang] Implement MATMUL in the runtime
Define an API for the transformational intrinsic function MATMUL,
implement it, and add some basic unit tests. The large number of
possible argument type combinations is covered by a set of
generalized templates that are instantiated for each valid
pair of possible argument types.
Places where BLAS-2/3 routines could be called for acceleration
are marked with TODOs. Handling of other special cases (e.g.,
known-shape 3x3 matrices and vectors) is deferred.
Some minor tweaks were made to the recent related implementation
of DOT_PRODUCT to reflect lessons learned.
Differential Revision: https://reviews.llvm.org/D102652
Added:
flang/runtime/matmul.cpp
flang/runtime/matmul.h
flang/unittests/RuntimeGTest/Matmul.cpp
Modified:
flang/runtime/CMakeLists.txt
flang/runtime/dot-product.cpp
flang/runtime/reduction.h
flang/unittests/RuntimeGTest/CMakeLists.txt
Removed:
################################################################################
diff --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt
index 84d13f12a1106..a484c94b0da1b 100644
--- a/flang/runtime/CMakeLists.txt
+++ b/flang/runtime/CMakeLists.txt
@@ -53,6 +53,7 @@ add_flang_library(FortranRuntime
io-error.cpp
io-stmt.cpp
main.cpp
+ matmul.cpp
memory.cpp
misc-intrinsic.cpp
namelist.cpp
diff --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp
index 1c83d8de3bf3c..075d987b4de02 100644
--- a/flang/runtime/dot-product.cpp
+++ b/flang/runtime/dot-product.cpp
@@ -15,9 +15,33 @@
namespace Fortran::runtime {
-template <typename ACCUMULATOR>
-static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y,
- Terminator &terminator) -> typename ACCUMULATOR::Result {
+template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
+class Accumulator { // running DOT_PRODUCT sum over element pairs of x and y
+public:
+ using Result = RESULT; // type in which the sum is kept and returned
+ Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
+ void Accumulate(SubscriptValue xAt, SubscriptValue yAt) { // adds one term
+ if constexpr (XCAT == TypeCategory::Complex) { // complex x is conjugated
+ sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
+ static_cast<Result>(*y_.Element<YT>(&yAt));
+ } else if constexpr (XCAT == TypeCategory::Logical) { // OR of pairwise ANDs
+ sum_ = sum_ ||
+ (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
+ } else { // other categories: plain product of operands converted to Result
+ sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
+ static_cast<Result>(*y_.Element<YT>(&yAt));
+ }
+ }
+ Result GetResult() const { return sum_; }
+
+private:
+ const Descriptor &x_, &y_; // held by reference, not copied
+ Result sum_{}; // value-initialized before the first term is added
+};
+
+template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
+static inline RESULT DoDotProduct(
+ const Descriptor &x, const Descriptor &y, Terminator &terminator) {
RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
SubscriptValue n{x.GetDimension(0).Extent()};
if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
@@ -25,18 +49,27 @@ static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y,
"DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
}
+ if constexpr (std::is_same_v<XT, YT>) {
+ if constexpr (std::is_same_v<XT, float>) {
+ // TODO: call BLAS-1 SDOT or SDSDOT
+ } else if constexpr (std::is_same_v<XT, double>) {
+ // TODO: call BLAS-1 DDOT
+ } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+ // TODO: call BLAS-1 CDOTC
+ } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+ // TODO: call BLAS-1 ZDOTC
+ }
+ }
SubscriptValue xAt{x.GetDimension(0).LowerBound()};
SubscriptValue yAt{y.GetDimension(0).LowerBound()};
- ACCUMULATOR accumulator{x, y};
+ Accumulator<RESULT, XCAT, XT, YT> accumulator{x, y};
for (SubscriptValue j{0}; j < n; ++j) {
accumulator.Accumulate(xAt++, yAt++);
}
return accumulator.GetResult();
}
-template <TypeCategory RCAT, int RKIND,
- template <typename, TypeCategory, typename, typename> class ACCUM>
-struct DotProduct {
+template <TypeCategory RCAT, int RKIND> struct DotProduct {
using Result = CppTypeFor<RCAT, RKIND>;
template <TypeCategory XCAT, int XKIND> struct DP1 {
template <TypeCategory YCAT, int YKIND> struct DP2 {
@@ -46,9 +79,8 @@ struct DotProduct {
GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
if constexpr (resultType->first == RCAT &&
resultType->second <= RKIND) {
- using Accum = ACCUM<Result, XCAT, CppTypeFor<XCAT, XKIND>,
- CppTypeFor<YCAT, YKIND>>;
- return DoDotProduct<Accum>(x, y, terminator);
+ return DoDotProduct<Result, XCAT, CppTypeFor<XCAT, XKIND>,
+ CppTypeFor<YCAT, YKIND>>(x, y, terminator);
}
}
terminator.Crash(
@@ -73,127 +105,76 @@ struct DotProduct {
}
};
-template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
-class NumericAccumulator {
-public:
- using Result = RESULT;
- NumericAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
- void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
- if constexpr (XCAT == TypeCategory::Complex) {
- sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
- static_cast<Result>(*y_.Element<YT>(&yAt));
- } else {
- sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
- static_cast<Result>(*y_.Element<YT>(&yAt));
- }
- }
- Result GetResult() const { return sum_; }
-
-private:
- const Descriptor &x_, &y_;
- Result sum_{0};
-};
-
-template <typename, TypeCategory, typename XT, typename YT>
-class LogicalAccumulator {
-public:
- using Result = bool;
- LogicalAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
- void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
- result_ = result_ ||
- (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
- }
- bool GetResult() const { return result_; }
-
-private:
- const Descriptor &x_, &y_;
- bool result_{false};
-};
-
extern "C" {
std::int8_t RTNAME(DotProductInteger1)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
- x, y, source, line);
+ return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
}
std::int16_t RTNAME(DotProductInteger2)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
- x, y, source, line);
+ return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
}
std::int32_t RTNAME(DotProductInteger4)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
- x, y, source, line);
+ return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
}
std::int64_t RTNAME(DotProductInteger8)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
- x, y, source, line);
+ return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
}
#ifdef __SIZEOF_INT128__
common::int128_t RTNAME(DotProductInteger16)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Integer, 16, NumericAccumulator>{}(
- x, y, source, line);
+ return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
}
#endif
// TODO: REAL/COMPLEX(2 & 3)
float RTNAME(DotProductReal4)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
- x, y, source, line);
+ return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
}
double RTNAME(DotProductReal8)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
- x, y, source, line);
+ return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
}
#if LONG_DOUBLE == 80
long double RTNAME(DotProductReal10)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Real, 10, NumericAccumulator>{}(
- x, y, source, line);
+ return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
}
#elif LONG_DOUBLE == 128
long double RTNAME(DotProductReal16)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Real, 16, NumericAccumulator>{}(
- x, y, source, line);
+ return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
}
#endif
void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- auto z{DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
- x, y, source, line)};
+ auto z{DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line)};
result = std::complex<float>{
static_cast<float>(z.real()), static_cast<float>(z.imag())};
}
void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- result = DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
- x, y, source, line);
+ result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
}
#if LONG_DOUBLE == 80
void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- result = DotProduct<TypeCategory::Complex, 10, NumericAccumulator>{}(
- x, y, source, line);
+ result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
}
#elif LONG_DOUBLE == 128
void RTNAME(CppDotProductComplex16)(std::complex<long double> &result,
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- result = DotProduct<TypeCategory::Complex, 16, NumericAccumulator>{}(
- x, y, source, line);
+ result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
}
#endif
bool RTNAME(DotProductLogical)(
const Descriptor &x, const Descriptor &y, const char *source, int line) {
- return DotProduct<TypeCategory::Logical, 1, LogicalAccumulator>{}(
- x, y, source, line);
+ return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
}
} // extern "C"
} // namespace Fortran::runtime
diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
new file mode 100644
index 0000000000000..3d10ca0a31c6b
--- /dev/null
+++ b/flang/runtime/matmul.cpp
@@ -0,0 +1,220 @@
+//===-- runtime/matmul.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Implements all forms of MATMUL (Fortran 2018 16.9.124)
+//
+// There are two main entry points; one establishes a descriptor for the
+// result and allocates it, and the other expects a result descriptor that
+// points to existing storage.
+//
+// This implementation must handle all combinations of numeric types and
+// kinds (100 - 165 cases depending on the target), plus all combinations
+// of logical kinds (16). A single template undergoes many instantiations
+// to cover all of the valid possibilities.
+//
+// Places where BLAS routines could be called are marked as TODO items.
+
+#include "matmul.h"
+#include "cpp-type.h"
+#include "descriptor.h"
+#include "terminator.h"
+#include "tools.h"
+
+namespace Fortran::runtime {
+
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+class Accumulator { // running sum of products for one MATMUL result element
+public:
+ // Accumulate floating-point results in (at least) double precision;
+ // other result categories accumulate in the result kind itself.
+ using Result = CppTypeFor<RCAT,
+ RCAT == TypeCategory::Real || RCAT == TypeCategory::Complex
+ ? std::max(RKIND, static_cast<int>(sizeof(double)))
+ : RKIND>;
+ Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
+ void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) {
+ if constexpr (RCAT == TypeCategory::Logical) { // OR of pairwise ANDs
+ sum_ = sum_ ||
+ (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt));
+ } else { // numeric: plain product, with no conjugation of either operand
+ sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) *
+ static_cast<Result>(*y_.Element<YT>(yAt));
+ }
+ }
+ Result GetResult() const { return sum_; }
+
+private:
+ const Descriptor &x_, &y_; // held by reference, not copied
+ Result sum_{}; // value-initialized before the first term is added
+};
+
+// Implements MATMUL for one (XT, YT) operand type pairing.  With IS_ALLOCATING
+// the result is established and allocated here; otherwise it is only checked.
+template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
+    typename YT>
+static inline void DoMatmul(
+    std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
+    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
+  int xRank{x.rank()}, yRank{y.rank()};
+  int resRank{xRank + yRank - 2}; // M*M -> 2; M*V and V*M -> 1
+  if (xRank < 1 || xRank > 2 || yRank < 1 || yRank > 2 || resRank < 1) {
+    terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
+  } // operands are now rank 1 or 2 with at least one matrix, so [2] arrays fit
+  SubscriptValue extent[2]{ // result extents; extent[1] is 0 for a vector
+      xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
+      resRank == 2 ? y.GetDimension(1).Extent() : 0};
+  if constexpr (IS_ALLOCATING) {
+    result.Establish(
+        RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
+    for (int j{0}; j < resRank; ++j) {
+      result.GetDimension(j).SetBounds(1, extent[j]); // lower bounds are 1
+    }
+    if (int stat{result.Allocate()}) {
+      terminator.Crash(
+          "MATMUL: could not allocate memory for result; STAT=%d", stat);
+    }
+  } else { // caller-supplied result: rank, type, and extents must match
+    RUNTIME_CHECK(terminator, resRank == result.rank());
+    RUNTIME_CHECK(terminator, result.type() == (TypeCode{RCAT, RKIND}));
+    RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
+    RUNTIME_CHECK(terminator,
+        resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
+  }
+  using WriteResult = // LOGICAL results are stored as same-kind integers
+      CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
+          RKIND>;
+  SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; // summed-over extent
+  if (n != y.GetDimension(0).Extent()) {
+    terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)",
+        static_cast<std::intmax_t>(n),
+        static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
+  }
+  SubscriptValue xAt[2], yAt[2], resAt[2]; // current subscripts per array
+  x.GetLowerBounds(xAt);
+  y.GetLowerBounds(yAt);
+  result.GetLowerBounds(resAt);
+  if (resRank == 2) { // M*M -> M
+    if constexpr (std::is_same_v<XT, YT>) {
+      if constexpr (std::is_same_v<XT, float>) {
+        // TODO: call BLAS-3 SGEMM
+      } else if constexpr (std::is_same_v<XT, double>) {
+        // TODO: call BLAS-3 DGEMM
+      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+        // TODO: call BLAS-3 CGEMM
+      } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+        // TODO: call BLAS-3 ZGEMM
+      }
+    }
+    SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
+    for (SubscriptValue i{0}; i < extent[0]; ++i) { // result rows
+      for (SubscriptValue j{0}; j < extent[1]; ++j) { // result columns
+        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+        yAt[1] = y1 + j;
+        for (SubscriptValue k{0}; k < n; ++k) {
+          xAt[1] = x1 + k;
+          yAt[0] = y0 + k;
+          accumulator.Accumulate(xAt, yAt);
+        }
+        resAt[1] = res1 + j;
+        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+      }
+      ++resAt[0];
+      ++xAt[0];
+    }
+  } else {
+    if constexpr (std::is_same_v<XT, YT>) {
+      if constexpr (std::is_same_v<XT, float>) {
+        // TODO: call BLAS-2 SGEMV
+      } else if constexpr (std::is_same_v<XT, double>) {
+        // TODO: call BLAS-2 DGEMV
+      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+        // TODO: call BLAS-2 CGEMV
+      } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+        // TODO: call BLAS-2 ZGEMV
+      }
+    }
+    if (xRank == 2) { // M*V -> V
+      SubscriptValue x1{xAt[1]}, y0{yAt[0]};
+      for (SubscriptValue j{0}; j < extent[0]; ++j) { // one row of x per element
+        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+        for (SubscriptValue k{0}; k < n; ++k) {
+          xAt[1] = x1 + k;
+          yAt[0] = y0 + k;
+          accumulator.Accumulate(xAt, yAt);
+        }
+        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+        ++resAt[0];
+        ++xAt[0];
+      }
+    } else { // V*M -> V
+      SubscriptValue x0{xAt[0]}, y0{yAt[0]};
+      for (SubscriptValue j{0}; j < extent[0]; ++j) { // one column of y each
+        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+        for (SubscriptValue k{0}; k < n; ++k) {
+          xAt[0] = x0 + k;
+          yAt[0] = y0 + k;
+          accumulator.Accumulate(xAt, yAt);
+        }
+        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+        ++resAt[0];
+        ++yAt[1];
+      }
+    }
+  }
+}
+
+// Maps the dynamic type information from the arguments' descriptors
+// to the right instantiation of DoMatmul() for valid combinations of
+// types.
+template <bool IS_ALLOCATING> struct Matmul {
+ using ResultDescriptor = // mutable only when this form allocates the result
+ std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+ template <TypeCategory XCAT, int XKIND> struct MM1 { // x's type is fixed
+ template <TypeCategory YCAT, int YKIND> struct MM2 { // y's type is fixed too
+ void operator()(ResultDescriptor &result, const Descriptor &x,
+ const Descriptor &y, Terminator &terminator) const {
+ if constexpr (constexpr auto resultType{
+ GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+ if constexpr (common::IsNumericTypeCategory(resultType->first) ||
+ resultType->first == TypeCategory::Logical) {
+ return DoMatmul<IS_ALLOCATING, resultType->first,
+ resultType->second, CppTypeFor<XCAT, XKIND>,
+ CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
+ }
+ }
+ terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))", // no match
+ static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
+ }
+ };
+ void operator()(ResultDescriptor &result, const Descriptor &x,
+ const Descriptor &y, Terminator &terminator, TypeCategory yCat,
+ int yKind) const {
+ ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
+ }
+ };
+ void operator()(ResultDescriptor &result, const Descriptor &x,
+ const Descriptor &y, const char *sourceFile, int line) const {
+ Terminator terminator{sourceFile, line}; // crash reports cite this position
+ auto xCatKind{x.type().GetCategoryAndKind()};
+ auto yCatKind{y.type().GetCategoryAndKind()};
+ RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+ ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
+ x, y, terminator, yCatKind->first, yCatKind->second); // x first, then y
+ }
+};
+
+extern "C" {
+void RTNAME(Matmul)(Descriptor &result, const Descriptor &x,
+ const Descriptor &y, const char *sourceFile, int line) {
+ Matmul<true>{}(result, x, y, sourceFile, line); // allocating form
+}
+void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x,
+ const Descriptor &y, const char *sourceFile, int line) {
+ Matmul<false>{}(result, x, y, sourceFile, line); // result is caller-supplied
+}
+} // extern "C"
+} // namespace Fortran::runtime
diff --git a/flang/runtime/matmul.h b/flang/runtime/matmul.h
new file mode 100644
index 0000000000000..8334d6670a1b0
--- /dev/null
+++ b/flang/runtime/matmul.h
@@ -0,0 +1,29 @@
+//===-- runtime/matmul.h ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// API for the transformational intrinsic function MATMUL.
+
+#ifndef FORTRAN_RUNTIME_MATMUL_H_
+#define FORTRAN_RUNTIME_MATMUL_H_
+#include "entry-names.h"
+namespace Fortran::runtime {
+class Descriptor;
+extern "C" {
+
+// The most general MATMUL. All type and shape information is taken from the
+// arguments' descriptors, and the result (first argument) is allocated here.
+void RTNAME(Matmul)(Descriptor &, const Descriptor &, const Descriptor &,
+ const char *sourceFile = nullptr, int line = 0);
+
+// A non-allocating variant; the result's descriptor (first argument) must
+// already be established and have a valid base address.
+void RTNAME(MatmulDirect)(const Descriptor &, const Descriptor &,
+ const Descriptor &, const char *sourceFile = nullptr, int line = 0);
+} // extern "C"
+} // namespace Fortran::runtime
+#endif // FORTRAN_RUNTIME_MATMUL_H_
diff --git a/flang/runtime/reduction.h b/flang/runtime/reduction.h
index cec30843da6d5..379fcb85cd1c5 100644
--- a/flang/runtime/reduction.h
+++ b/flang/runtime/reduction.h
@@ -7,9 +7,6 @@
//===----------------------------------------------------------------------===//
// Defines the API for the reduction transformational intrinsic functions.
-// (Except the complex-valued DOT_PRODUCT and the complex-valued total reduction
-// forms of SUM & PRODUCT; the API for those is in complex-reduction.h so that
-// C's _Complex can be used for their return types.)
#ifndef FORTRAN_RUNTIME_REDUCTION_H_
#define FORTRAN_RUNTIME_REDUCTION_H_
@@ -36,10 +33,10 @@ extern "C" {
// results in a caller-supplied descriptor, which is assumed to
// be large enough.
//
-// Complex-valued SUM and PRODUCT reductions have their API
-// entry points defined in complex-reduction.h; these are C wrappers
-// around C++ implementations so as to keep usage of C's _Complex
-// types out of C++ code.
+// Complex-valued SUM and PRODUCT reductions and complex-valued
+// DOT_PRODUCT have their API entry points defined in complex-reduction.h;
+// those entry points are C wrappers around C++ implementations so as to
+// keep usage of C's _Complex types out of C++ code.
// SUM()
diff --git a/flang/unittests/RuntimeGTest/CMakeLists.txt b/flang/unittests/RuntimeGTest/CMakeLists.txt
index cad827a8a9668..3d45cf6dc877b 100644
--- a/flang/unittests/RuntimeGTest/CMakeLists.txt
+++ b/flang/unittests/RuntimeGTest/CMakeLists.txt
@@ -2,6 +2,7 @@ add_flang_unittest(FlangRuntimeTests
CharacterTest.cpp
CrashHandlerFixture.cpp
Format.cpp
+ Matmul.cpp
MiscIntrinsic.cpp
Namelist.cpp
Numeric.cpp
diff --git a/flang/unittests/RuntimeGTest/Matmul.cpp b/flang/unittests/RuntimeGTest/Matmul.cpp
new file mode 100644
index 0000000000000..ae9e7a84236c8
--- /dev/null
+++ b/flang/unittests/RuntimeGTest/Matmul.cpp
@@ -0,0 +1,98 @@
+//===-- flang/unittests/RuntimeGTest/Matmul.cpp---- -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../../runtime/matmul.h"
+#include "gtest/gtest.h"
+#include "tools.h"
+#include "../../runtime/allocatable.h"
+#include "../../runtime/cpp-type.h"
+#include "../../runtime/descriptor.h"
+#include "../../runtime/type-code.h"
+
+using namespace Fortran::runtime;
+using Fortran::common::TypeCategory;
+
+TEST(Matmul, Basic) {
+ // X  0 2 4    Y  6  9    V  -1 -2     (data below is in column-major order)
+ //    1 3 5       7 10
+ //                8 11
+ auto x{MakeArray<TypeCategory::Integer, 4>(
+ std::vector<int>{2, 3}, std::vector<std::int32_t>{0, 1, 2, 3, 4, 5})};
+ auto y{MakeArray<TypeCategory::Integer, 2>(
+ std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
+ auto v{MakeArray<TypeCategory::Integer, 8>(
+ std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
+ StaticDescriptor<2> statDesc;
+ Descriptor &result{statDesc.descriptor()};
+
+ RTNAME(Matmul)(result, *x, *y, __FILE__, __LINE__); // matrix * matrix
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+
+ std::memset( // clear the storage so the next call must rewrite every element
+ result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
+ result.GetDimension(0).SetLowerBound(0); // non-default lower bounds
+ result.GetDimension(1).SetLowerBound(2);
+ RTNAME(MatmulDirect)(result, *x, *y, __FILE__, __LINE__); // reuses storage
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
+ RTNAME(Matmul)(result, *v, *x, __FILE__, __LINE__); // vector * matrix
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
+ result.Destroy();
+
+ RTNAME(Matmul)(result, *y, *v, __FILE__, __LINE__); // matrix * vector
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+ result.Destroy();
+
+ // X  F F T    Y  F T
+ //    F T F       F T
+ //                F F
+ auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{2, 3},
+ std::vector<std::uint8_t>{false, false, false, true, true, false})};
+ auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2},
+ std::vector<std::uint16_t>{false, false, false, true, true, false})};
+ RTNAME(Matmul)(result, *xLog, *yLog, __FILE__, __LINE__); // mixed kinds
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
+ EXPECT_FALSE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
+ EXPECT_FALSE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+ EXPECT_FALSE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
+ EXPECT_TRUE( // only X row 2 and Y column 2 share a true position
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+}
More information about the flang-commits
mailing list