[flang-commits] [flang] 5e1421b - [flang] Implement MATMUL in the runtime

peter klausler via flang-commits flang-commits at lists.llvm.org
Tue May 18 11:00:02 PDT 2021


Author: peter klausler
Date: 2021-05-18T10:59:52-07:00
New Revision: 5e1421b22f642a6b34690d0d724e691ba3984836

URL: https://github.com/llvm/llvm-project/commit/5e1421b22f642a6b34690d0d724e691ba3984836
DIFF: https://github.com/llvm/llvm-project/commit/5e1421b22f642a6b34690d0d724e691ba3984836.diff

LOG: [flang] Implement MATMUL in the runtime

Define an API for the transformational intrinsic function MATMUL,
implement it, and add some basic unit tests.  The large number of
possible argument type combinations is covered by a set of
generalized templates that are instantiated for each valid
pair of possible argument types.

Places where BLAS-2/3 routines could be called for acceleration
are marked with TODOs.  Handling for other special cases (e.g.,
known-shape 3x3 matrices and vectors) is deferred.

Some minor tweaks were made to the recent related implementation
of DOT_PRODUCT to reflect lessons learned.

Differential Revision: https://reviews.llvm.org/D102652

Added: 
    flang/runtime/matmul.cpp
    flang/runtime/matmul.h
    flang/unittests/RuntimeGTest/Matmul.cpp

Modified: 
    flang/runtime/CMakeLists.txt
    flang/runtime/dot-product.cpp
    flang/runtime/reduction.h
    flang/unittests/RuntimeGTest/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt
index 84d13f12a1106..a484c94b0da1b 100644
--- a/flang/runtime/CMakeLists.txt
+++ b/flang/runtime/CMakeLists.txt
@@ -53,6 +53,7 @@ add_flang_library(FortranRuntime
   io-error.cpp
   io-stmt.cpp
   main.cpp
+  matmul.cpp
   memory.cpp
   misc-intrinsic.cpp
   namelist.cpp

diff  --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp
index 1c83d8de3bf3c..075d987b4de02 100644
--- a/flang/runtime/dot-product.cpp
+++ b/flang/runtime/dot-product.cpp
@@ -15,9 +15,33 @@
 
 namespace Fortran::runtime {
 
-template <typename ACCUMULATOR>
-static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y,
-    Terminator &terminator) -> typename ACCUMULATOR::Result {
+// Sums the products of corresponding elements of the two DOT_PRODUCT
+// vectors.  COMPLEX elements of x are conjugated before multiplication;
+// for LOGICAL vectors the "sum" is OR and the "product" is AND.
+template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
+class Accumulator {
+public:
+  using Result = RESULT;
+
+  Accumulator(const Descriptor &x, const Descriptor &y)
+      : xDesc_{x}, yDesc_{y} {}
+
+  // Folds the elements x(xAt) and y(yAt) into the running sum.
+  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
+    if constexpr (XCAT == TypeCategory::Logical) {
+      if (!sum_) { // once true, the result stays true
+        sum_ = IsLogicalElementTrue(xDesc_, &xAt) &&
+            IsLogicalElementTrue(yDesc_, &yAt);
+      }
+    } else {
+      Result xElement{static_cast<Result>(*xDesc_.Element<XT>(&xAt))};
+      const Result yElement{static_cast<Result>(*yDesc_.Element<YT>(&yAt))};
+      if constexpr (XCAT == TypeCategory::Complex) {
+        xElement = std::conj(xElement);
+      }
+      sum_ += xElement * yElement;
+    }
+  }
+
+  Result GetResult() const { return sum_; }
+
+private:
+  const Descriptor &xDesc_;
+  const Descriptor &yDesc_;
+  Result sum_{};
+};
+
+// Computes DOT_PRODUCT(x, y) for rank-1 arguments whose element types are
+// XT and YT, accumulating in type RESULT.  Crashes through the terminator
+// when either argument is not rank 1 or when the two sizes differ.
+template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
+static inline RESULT DoDotProduct(
+    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
   RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
   SubscriptValue n{x.GetDimension(0).Extent()};
   if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
@@ -25,18 +49,27 @@ static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y,
         "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
         static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
   }
+  // Candidate BLAS-1 accelerations when both vectors share an element type.
+  // None is implemented yet, so every case falls through to the loop below.
+  if constexpr (std::is_same_v<XT, YT>) {
+    if constexpr (std::is_same_v<XT, float>) {
+      // TODO: call BLAS-1 SDOT or SDSDOT
+    } else if constexpr (std::is_same_v<XT, double>) {
+      // TODO: call BLAS-1 DDOT
+    } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+      // TODO: call BLAS-1 CDOTC
+    } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+      // TODO: call BLAS-1 ZDOTC
+    }
+  }
+  // Walk both vectors in step from their respective lower bounds.
   SubscriptValue xAt{x.GetDimension(0).LowerBound()};
   SubscriptValue yAt{y.GetDimension(0).LowerBound()};
-  ACCUMULATOR accumulator{x, y};
+  Accumulator<RESULT, XCAT, XT, YT> accumulator{x, y};
   for (SubscriptValue j{0}; j < n; ++j) {
     accumulator.Accumulate(xAt++, yAt++);
   }
   return accumulator.GetResult();
 }
 
-template <TypeCategory RCAT, int RKIND,
-    template <typename, TypeCategory, typename, typename> class ACCUM>
-struct DotProduct {
+template <TypeCategory RCAT, int RKIND> struct DotProduct {
   using Result = CppTypeFor<RCAT, RKIND>;
   template <TypeCategory XCAT, int XKIND> struct DP1 {
     template <TypeCategory YCAT, int YKIND> struct DP2 {
@@ -46,9 +79,8 @@ struct DotProduct {
                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
           if constexpr (resultType->first == RCAT &&
               resultType->second <= RKIND) {
-            using Accum = ACCUM<Result, XCAT, CppTypeFor<XCAT, XKIND>,
-                CppTypeFor<YCAT, YKIND>>;
-            return DoDotProduct<Accum>(x, y, terminator);
+            return DoDotProduct<Result, XCAT, CppTypeFor<XCAT, XKIND>,
+                CppTypeFor<YCAT, YKIND>>(x, y, terminator);
           }
         }
         terminator.Crash(
@@ -73,127 +105,76 @@ struct DotProduct {
   }
 };
 
-template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
-class NumericAccumulator {
-public:
-  using Result = RESULT;
-  NumericAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
-  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
-    if constexpr (XCAT == TypeCategory::Complex) {
-      sum_ += std::conj(static_cast<Result>(*x_.Element<XT>(&xAt))) *
-          static_cast<Result>(*y_.Element<YT>(&yAt));
-    } else {
-      sum_ += static_cast<Result>(*x_.Element<XT>(&xAt)) *
-          static_cast<Result>(*y_.Element<YT>(&yAt));
-    }
-  }
-  Result GetResult() const { return sum_; }
-
-private:
-  const Descriptor &x_, &y_;
-  Result sum_{0};
-};
-
-template <typename, TypeCategory, typename XT, typename YT>
-class LogicalAccumulator {
-public:
-  using Result = bool;
-  LogicalAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
-  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
-    result_ = result_ ||
-        (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
-  }
-  bool GetResult() const { return result_; }
-
-private:
-  const Descriptor &x_, &y_;
-  bool result_{false};
-};
-
 extern "C" {
+// C-callable DOT_PRODUCT entry points.  The INTEGER(1, 2, 4 & 8),
+// REAL(4 & 8), and COMPLEX(4 & 8) forms all instantiate DotProduct with
+// kind 8, so products are summed in a 64-bit INTEGER, REAL(8), or
+// COMPLEX(8) accumulator and then converted to the declared return type.
 std::int8_t RTNAME(DotProductInteger1)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
 }
 std::int16_t RTNAME(DotProductInteger2)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
 }
 std::int32_t RTNAME(DotProductInteger4)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
 }
 std::int64_t RTNAME(DotProductInteger8)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line);
 }
 #ifdef __SIZEOF_INT128__
 common::int128_t RTNAME(DotProductInteger16)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Integer, 16, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line);
 }
 #endif
 
 // TODO: REAL/COMPLEX(2 & 3)
 float RTNAME(DotProductReal4)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
 }
 double RTNAME(DotProductReal8)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Real, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line);
 }
 #if LONG_DOUBLE == 80
 long double RTNAME(DotProductReal10)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Real, 10, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line);
 }
 #elif LONG_DOUBLE == 128
 long double RTNAME(DotProductReal16)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Real, 16, NumericAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line);
 }
 #endif
 
+// COMPLEX(4) is computed in COMPLEX(8) and narrowed component-wise.
 void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  auto z{DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
-      x, y, source, line)};
+  auto z{DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line)};
   result = std::complex<float>{
       static_cast<float>(z.real()), static_cast<float>(z.imag())};
 }
 void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  result = DotProduct<TypeCategory::Complex, 8, NumericAccumulator>{}(
-      x, y, source, line);
+  result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line);
 }
 #if LONG_DOUBLE == 80
 void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  result = DotProduct<TypeCategory::Complex, 10, NumericAccumulator>{}(
-      x, y, source, line);
+  result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line);
 }
 #elif LONG_DOUBLE == 128
 void RTNAME(CppDotProductComplex16)(std::complex<long double> &result,
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  result = DotProduct<TypeCategory::Complex, 16, NumericAccumulator>{}(
-      x, y, source, line);
+  result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line);
 }
 #endif
 
+// LOGICAL DOT_PRODUCT: true when some pair of corresponding elements are
+// both true.
+// NOTE(review): DotProduct<Logical, 1> only accepts result kinds <= 1 (its
+// "resultType->second <= RKIND" test), so LOGICAL arguments of kind > 1
+// appear to reach terminator.Crash; confirm against GetResultType.
 bool RTNAME(DotProductLogical)(
     const Descriptor &x, const Descriptor &y, const char *source, int line) {
-  return DotProduct<TypeCategory::Logical, 1, LogicalAccumulator>{}(
-      x, y, source, line);
+  return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line);
 }
 } // extern "C"
 } // namespace Fortran::runtime

diff  --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
new file mode 100644
index 0000000000000..3d10ca0a31c6b
--- /dev/null
+++ b/flang/runtime/matmul.cpp
@@ -0,0 +1,220 @@
+//===-- runtime/matmul.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Implements all forms of MATMUL (Fortran 2018 16.9.124)
+//
+// There are two main entry points; one establishes a descriptor for the
+// result and allocates it, and the other expects a result descriptor that
+// points to existing storage.
+//
+// This implementation must handle all combinations of numeric types and
+// kinds (100 - 165 cases depending on the target), plus all combinations
+// of logical kinds (16).  A single template undergoes many instantiations
+// to cover all of the valid possibilities.
+//
+// Places where BLAS routines could be called are marked as TODO items.
+
+#include "matmul.h"
+#include "cpp-type.h"
+#include "descriptor.h"
+#include "terminator.h"
+#include "tools.h"
+
+namespace Fortran::runtime {
+
+// Sums the products of corresponding elements of a row (or vector) of x
+// and a column (or vector) of y, yielding one element of a MATMUL result.
+// For LOGICAL results the "sum" is OR and the "product" is AND.
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+class Accumulator {
+  // REAL and COMPLEX sums are carried in (at least) double precision.
+  static constexpr bool isFloatingPoint{
+      RCAT == TypeCategory::Real || RCAT == TypeCategory::Complex};
+  static constexpr int sumKind{isFloatingPoint
+          ? std::max(RKIND, static_cast<int>(sizeof(double)))
+          : RKIND};
+
+public:
+  using Result = CppTypeFor<RCAT, sumKind>;
+
+  Accumulator(const Descriptor &x, const Descriptor &y)
+      : xDesc_{x}, yDesc_{y} {}
+
+  // Folds the elements x(xAt) and y(yAt) into the running sum.
+  void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) {
+    if constexpr (RCAT != TypeCategory::Logical) {
+      const Result xElement{static_cast<Result>(*xDesc_.Element<XT>(xAt))};
+      const Result yElement{static_cast<Result>(*yDesc_.Element<YT>(yAt))};
+      sum_ += xElement * yElement;
+    } else if (!sum_) {
+      // Once true, the result stays true, so later reads are skipped.
+      sum_ = IsLogicalElementTrue(xDesc_, xAt) &&
+          IsLogicalElementTrue(yDesc_, yAt);
+    }
+  }
+
+  Result GetResult() const { return sum_; }
+
+private:
+  const Descriptor &xDesc_;
+  const Descriptor &yDesc_;
+  Result sum_{};
+};
+
+// Implements an instance of MATMUL for given argument types.
+//
+// Three shapes are supported: M*M -> M (both arguments rank 2),
+// M*V -> V (rank 2 * rank 1), and V*M -> V (rank 1 * rank 2); any other
+// pair of ranks, or a mismatch in the contracted extent, crashes.
+// When IS_ALLOCATING, the result descriptor is established as an
+// allocatable array with lower bounds of 1 and is then allocated.
+// Otherwise the caller's result must already have the right rank, type,
+// and extents, but may have any lower bounds.
+template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
+    typename YT>
+static inline void DoMatmul(
+    std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
+    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
+  int xRank{x.rank()};
+  int yRank{y.rank()};
+  int resRank{xRank + yRank - 2};
+  // Each argument must be rank 1 or 2, and they cannot both be rank 1.
+  // This also keeps every rank within the two-element subscript and
+  // extent arrays used below.
+  if (xRank < 1 || xRank > 2 || yRank < 1 || yRank > 2 || resRank < 1) {
+    terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
+  }
+  // n is the contracted extent: the last dimension of x, the first of y.
+  // Validate it before the result descriptor is touched.
+  SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
+  if (n != y.GetDimension(0).Extent()) {
+    terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)",
+        static_cast<std::intmax_t>(n),
+        static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
+  }
+  // Result extents: the first comes from the rows of x (or, for V*M, the
+  // columns of y); the second, when present, from the columns of y.
+  SubscriptValue extent[2]{
+      xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
+      resRank == 2 ? y.GetDimension(1).Extent() : 0};
+  if constexpr (IS_ALLOCATING) {
+    result.Establish(
+        RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
+    for (int j{0}; j < resRank; ++j) {
+      result.GetDimension(j).SetBounds(1, extent[j]);
+    }
+    if (int stat{result.Allocate()}) {
+      terminator.Crash(
+          "MATMUL: could not allocate memory for result; STAT=%d", stat);
+    }
+  } else {
+    RUNTIME_CHECK(terminator, resRank == result.rank());
+    RUNTIME_CHECK(terminator, result.type() == (TypeCode{RCAT, RKIND}));
+    RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
+    RUNTIME_CHECK(terminator,
+        resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
+  }
+  // LOGICAL results are stored through the INTEGER type of the same kind.
+  using WriteResult =
+      CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
+          RKIND>;
+  SubscriptValue xAt[2], yAt[2], resAt[2];
+  x.GetLowerBounds(xAt);
+  y.GetLowerBounds(yAt);
+  result.GetLowerBounds(resAt);
+  if (resRank == 2) { // M*M -> M
+    if constexpr (std::is_same_v<XT, YT>) {
+      if constexpr (std::is_same_v<XT, float>) {
+        // TODO: call BLAS-3 SGEMM
+      } else if constexpr (std::is_same_v<XT, double>) {
+        // TODO: call BLAS-3 DGEMM
+      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+        // TODO: call BLAS-3 CGEMM
+      } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+        // TODO: call BLAS-3 ZGEMM
+      }
+    }
+    // Element (i,j) of the result accumulates x(i,k) * y(k,j) over k,
+    // with i, j, and k taken as offsets from each array's lower bounds.
+    SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
+    for (SubscriptValue i{0}; i < extent[0]; ++i) {
+      for (SubscriptValue j{0}; j < extent[1]; ++j) {
+        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+        yAt[1] = y1 + j;
+        for (SubscriptValue k{0}; k < n; ++k) {
+          xAt[1] = x1 + k;
+          yAt[0] = y0 + k;
+          accumulator.Accumulate(xAt, yAt);
+        }
+        resAt[1] = res1 + j;
+        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+      }
+      ++resAt[0];
+      ++xAt[0];
+    }
+  } else {
+    if constexpr (std::is_same_v<XT, YT>) {
+      if constexpr (std::is_same_v<XT, float>) {
+        // TODO: call BLAS-2 SGEMV
+      } else if constexpr (std::is_same_v<XT, double>) {
+        // TODO: call BLAS-2 DGEMV
+      } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
+        // TODO: call BLAS-2 CGEMV
+      } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
+        // TODO: call BLAS-2 ZGEMV
+      }
+    }
+    if (xRank == 2) { // M*V -> V
+      // Element j of the result accumulates x(j,k) * y(k) over k.
+      SubscriptValue x1{xAt[1]}, y0{yAt[0]};
+      for (SubscriptValue j{0}; j < extent[0]; ++j) {
+        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+        for (SubscriptValue k{0}; k < n; ++k) {
+          xAt[1] = x1 + k;
+          yAt[0] = y0 + k;
+          accumulator.Accumulate(xAt, yAt);
+        }
+        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+        ++resAt[0];
+        ++xAt[0];
+      }
+    } else { // V*M -> V
+      // Element j of the result accumulates x(k) * y(k,j) over k.
+      SubscriptValue x0{xAt[0]}, y0{yAt[0]};
+      for (SubscriptValue j{0}; j < extent[0]; ++j) {
+        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
+        for (SubscriptValue k{0}; k < n; ++k) {
+          xAt[0] = x0 + k;
+          yAt[0] = y0 + k;
+          accumulator.Accumulate(xAt, yAt);
+        }
+        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
+        ++resAt[0];
+        ++yAt[1];
+      }
+    }
+  }
+}
+
+// Dispatches on the dynamic types recorded in the argument descriptors to
+// the instantiation of DoMatmul() for that pair of types, and crashes when
+// the pair has no numeric or LOGICAL result type.
+template <bool IS_ALLOCATING> struct Matmul {
+  using ResultDescriptor =
+      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+
+  // The type of x is fixed at compile time; resolves the type of y next.
+  template <TypeCategory XCAT, int XKIND> struct ForX {
+    // Both argument types are fixed at compile time.
+    template <TypeCategory YCAT, int YKIND> struct ForY {
+      void operator()(ResultDescriptor &result, const Descriptor &x,
+          const Descriptor &y, Terminator &terminator) const {
+        constexpr auto resultType{GetResultType(XCAT, XKIND, YCAT, YKIND)};
+        if constexpr (resultType) {
+          constexpr auto resultCat{resultType->first};
+          if constexpr (common::IsNumericTypeCategory(resultCat) ||
+              resultCat == TypeCategory::Logical) {
+            DoMatmul<IS_ALLOCATING, resultCat, resultType->second,
+                CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
+                result, x, y, terminator);
+            return;
+          }
+        }
+        terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
+            static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
+      }
+    };
+
+    void operator()(ResultDescriptor &result, const Descriptor &x,
+        const Descriptor &y, Terminator &terminator, TypeCategory yCat,
+        int yKind) const {
+      ApplyType<ForY, void>(yCat, yKind, terminator, result, x, y, terminator);
+    }
+  };
+
+  // Entry: resolves the type of x, carrying the type of y along.
+  void operator()(ResultDescriptor &result, const Descriptor &x,
+      const Descriptor &y, const char *sourceFile, int line) const {
+    Terminator terminator{sourceFile, line};
+    auto xCatKind{x.type().GetCategoryAndKind()};
+    auto yCatKind{y.type().GetCategoryAndKind()};
+    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+    auto [xCat, xKind]{*xCatKind};
+    auto [yCat, yKind]{*yCatKind};
+    ApplyType<ForX, void>(
+        xCat, xKind, terminator, result, x, y, terminator, yCat, yKind);
+  }
+};
+
+extern "C" {
+// Allocating form: establishes and allocates the result descriptor.
+void RTNAME(Matmul)(Descriptor &result, const Descriptor &x,
+    const Descriptor &y, const char *sourceFile, int line) {
+  Matmul<true> allocatingMatmul;
+  allocatingMatmul(result, x, y, sourceFile, line);
+}
+// Direct form: writes into the caller's already-established result.
+void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x,
+    const Descriptor &y, const char *sourceFile, int line) {
+  Matmul<false> directMatmul;
+  directMatmul(result, x, y, sourceFile, line);
+}
+} // extern "C"
+} // namespace Fortran::runtime

diff  --git a/flang/runtime/matmul.h b/flang/runtime/matmul.h
new file mode 100644
index 0000000000000..8334d6670a1b0
--- /dev/null
+++ b/flang/runtime/matmul.h
@@ -0,0 +1,29 @@
+//===-- runtime/matmul.h ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// API for the transformational intrinsic function MATMUL.
+
+#ifndef FORTRAN_RUNTIME_MATMUL_H_
+#define FORTRAN_RUNTIME_MATMUL_H_
+#include "entry-names.h"
+namespace Fortran::runtime {
+class Descriptor;
+extern "C" {
+
+// The most general MATMUL(x, y).  All type and shape information is taken
+// from the descriptors of x and y; the result descriptor is established
+// here and its storage is dynamically allocated.
+void RTNAME(Matmul)(Descriptor &result, const Descriptor &x,
+    const Descriptor &y, const char *sourceFile = nullptr, int line = 0);
+
+// A non-allocating variant: the result's descriptor must already be
+// established with the result's rank, type, and extents, and must have a
+// valid base address.  Its lower bounds may be arbitrary.
+void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x,
+    const Descriptor &y, const char *sourceFile = nullptr, int line = 0);
+} // extern "C"
+} // namespace Fortran::runtime
+#endif // FORTRAN_RUNTIME_MATMUL_H_

diff  --git a/flang/runtime/reduction.h b/flang/runtime/reduction.h
index cec30843da6d5..379fcb85cd1c5 100644
--- a/flang/runtime/reduction.h
+++ b/flang/runtime/reduction.h
@@ -7,9 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 // Defines the API for the reduction transformational intrinsic functions.
-// (Except the complex-valued DOT_PRODUCT and the complex-valued total reduction
-// forms of SUM & PRODUCT; the API for those is in complex-reduction.h so that
-// C's _Complex can be used for their return types.)
 
 #ifndef FORTRAN_RUNTIME_REDUCTION_H_
 #define FORTRAN_RUNTIME_REDUCTION_H_
@@ -36,10 +33,10 @@ extern "C" {
 // results in a caller-supplied descriptor, which is assumed to
 // be large enough.
 //
-// Complex-valued SUM and PRODUCT reductions have their API
-// entry points defined in complex-reduction.h; these are C wrappers
-// around C++ implementations so as to keep usage of C's _Complex
-// types out of C++ code.
+// Complex-valued SUM and PRODUCT reductions and complex-valued
+// DOT_PRODUCT have their API entry points defined in complex-reduction.h;
+// those entry points are C wrappers around C++ implementations so as to
+// keep usage of C's _Complex types out of C++ code.
 
 // SUM()
 

diff  --git a/flang/unittests/RuntimeGTest/CMakeLists.txt b/flang/unittests/RuntimeGTest/CMakeLists.txt
index cad827a8a9668..3d45cf6dc877b 100644
--- a/flang/unittests/RuntimeGTest/CMakeLists.txt
+++ b/flang/unittests/RuntimeGTest/CMakeLists.txt
@@ -2,6 +2,7 @@ add_flang_unittest(FlangRuntimeTests
   CharacterTest.cpp
   CrashHandlerFixture.cpp
   Format.cpp
+  Matmul.cpp
   MiscIntrinsic.cpp
   Namelist.cpp
   Numeric.cpp

diff  --git a/flang/unittests/RuntimeGTest/Matmul.cpp b/flang/unittests/RuntimeGTest/Matmul.cpp
new file mode 100644
index 0000000000000..ae9e7a84236c8
--- /dev/null
+++ b/flang/unittests/RuntimeGTest/Matmul.cpp
@@ -0,0 +1,98 @@
+//===-- flang/unittests/RuntimeGTest/Matmul.cpp -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../../runtime/matmul.h"
+#include "gtest/gtest.h"
+#include "tools.h"
+#include "../../runtime/allocatable.h"
+#include "../../runtime/cpp-type.h"
+#include "../../runtime/descriptor.h"
+#include "../../runtime/type-code.h"
+
+using namespace Fortran::runtime;
+using Fortran::common::TypeCategory;
+
+// Exercises the allocating and direct forms of MATMUL on M*M, V*M, and M*V
+// INTEGER cases with mixed kinds, and on an M*M LOGICAL case.  Arrays are
+// column-major, so ZeroBasedIndexedElement(1) of a 2x2 result is (2,1).
+TEST(Matmul, Basic) {
+  // X 0 2 4   Y 6  9   V -1 -2
+  //   1 3 5     7 10
+  //             8 11
+  auto x{MakeArray<TypeCategory::Integer, 4>(
+      std::vector<int>{2, 3}, std::vector<std::int32_t>{0, 1, 2, 3, 4, 5})};
+  auto y{MakeArray<TypeCategory::Integer, 2>(
+      std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
+  auto v{MakeArray<TypeCategory::Integer, 8>(
+      std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
+  StaticDescriptor<2> statDesc;
+  Descriptor &result{statDesc.descriptor()};
+
+  // X * Y: INTEGER(4) * INTEGER(2) -> 2x2 INTEGER(4)
+  RTNAME(Matmul)(result, *x, *y, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+
+  // Recompute X * Y into the same storage through a result whose lower
+  // bounds are not 1; MatmulDirect must honor those bounds.
+  std::memset(
+      result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
+  result.GetDimension(0).SetLowerBound(0);
+  result.GetDimension(1).SetLowerBound(2);
+  RTNAME(MatmulDirect)(result, *x, *y, __FILE__, __LINE__);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+  result.Destroy();
+
+  // V * X: INTEGER(8) * INTEGER(4) -> rank-1 INTEGER(8)
+  RTNAME(Matmul)(result, *v, *x, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
+  result.Destroy();
+
+  // Y * V: INTEGER(2) * INTEGER(8) -> rank-1 INTEGER(8)
+  RTNAME(Matmul)(result, *y, *v, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+  result.Destroy();
+
+  // X F F T  Y F T
+  //   F T F    F T
+  //            F F
+  auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{2, 3},
+      std::vector<std::uint8_t>{false, false, false, true, true, false})};
+  auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2},
+      std::vector<std::uint16_t>{false, false, false, true, true, false})};
+  RTNAME(Matmul)(result, *xLog, *yLog, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
+  EXPECT_TRUE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+  result.Destroy();
+}


        


More information about the flang-commits mailing list