[flang-commits] [flang] [flang][runtime] Split MATMUL[_TRANSPOSE] into separate entries. (PR #97406)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Tue Jul 2 05:06:48 PDT 2024


https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/97406

Device compilation is much faster for separate MATMUL[_TRANSPOSE]
entries than for a single one that covers all data types.
The lowering changes and the removal of the generic entries will follow.


>From 4a594ea790d13e0e641bfd576ba82e3736483859 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 2 Jul 2024 04:56:55 -0700
Subject: [PATCH] [flang][runtime] Split MATMUL[_TRANSPOSE] into separate
 entries.

Device compilation is much faster for separate MATMUL[_TRANSPOSE]
entries than for a single one that covers all data types.
The lowering changes and the removal of the generic entries will follow.
---
 .../flang/Runtime/matmul-instances.inc        | 261 ++++++++++++++++++
 .../include/flang/Runtime/matmul-transpose.h  |  17 ++
 flang/include/flang/Runtime/matmul.h          |  17 ++
 flang/runtime/matmul-transpose.cpp            |  42 +++
 flang/runtime/matmul.cpp                      |  50 +++-
 flang/unittests/Runtime/Matmul.cpp            | 121 ++++++++
 flang/unittests/Runtime/MatmulTranspose.cpp   | 140 ++++++++++
 7 files changed, 646 insertions(+), 2 deletions(-)
 create mode 100644 flang/include/flang/Runtime/matmul-instances.inc

diff --git a/flang/include/flang/Runtime/matmul-instances.inc b/flang/include/flang/Runtime/matmul-instances.inc
new file mode 100644
index 0000000000000..970b03339cd5e
--- /dev/null
+++ b/flang/include/flang/Runtime/matmul-instances.inc
@@ -0,0 +1,261 @@
+//===-- include/flang/Runtime/matmul-instances.inc --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Expands MATMUL_INSTANCE and MATMUL_DIRECT_INSTANCE once per supported pair
+// of operand (category, KIND); every macro used here is #undef'd afterwards.
+//===----------------------------------------------------------------------===//
+
+#ifndef MATMUL_INSTANCE
+#error "Define MATMUL_INSTANCE before including this file"
+#endif
+
+#ifndef MATMUL_DIRECT_INSTANCE
+#error "Define MATMUL_DIRECT_INSTANCE before including this file"
+#endif
+
+// clang-format off
+
+#define FOREACH_MATMUL_TYPE_PAIR(macro)         \
+  macro(Integer, 1, Integer, 1)                 \
+  macro(Integer, 1, Integer, 2)                 \
+  macro(Integer, 1, Integer, 4)                 \
+  macro(Integer, 1, Integer, 8)                 \
+  macro(Integer, 2, Integer, 1)                 \
+  macro(Integer, 2, Integer, 2)                 \
+  macro(Integer, 2, Integer, 4)                 \
+  macro(Integer, 2, Integer, 8)                 \
+  macro(Integer, 4, Integer, 1)                 \
+  macro(Integer, 4, Integer, 2)                 \
+  macro(Integer, 4, Integer, 4)                 \
+  macro(Integer, 4, Integer, 8)                 \
+  macro(Integer, 8, Integer, 1)                 \
+  macro(Integer, 8, Integer, 2)                 \
+  macro(Integer, 8, Integer, 4)                 \
+  macro(Integer, 8, Integer, 8)                 \
+  macro(Integer, 1, Real, 4)                    \
+  macro(Integer, 1, Real, 8)                    \
+  macro(Integer, 2, Real, 4)                    \
+  macro(Integer, 2, Real, 8)                    \
+  macro(Integer, 4, Real, 4)                    \
+  macro(Integer, 4, Real, 8)                    \
+  macro(Integer, 8, Real, 4)                    \
+  macro(Integer, 8, Real, 8)                    \
+  macro(Integer, 1, Complex, 4)                 \
+  macro(Integer, 1, Complex, 8)                 \
+  macro(Integer, 2, Complex, 4)                 \
+  macro(Integer, 2, Complex, 8)                 \
+  macro(Integer, 4, Complex, 4)                 \
+  macro(Integer, 4, Complex, 8)                 \
+  macro(Integer, 8, Complex, 4)                 \
+  macro(Integer, 8, Complex, 8)                 \
+  macro(Real, 4, Integer, 1)                    \
+  macro(Real, 4, Integer, 2)                    \
+  macro(Real, 4, Integer, 4)                    \
+  macro(Real, 4, Integer, 8)                    \
+  macro(Real, 8, Integer, 1)                    \
+  macro(Real, 8, Integer, 2)                    \
+  macro(Real, 8, Integer, 4)                    \
+  macro(Real, 8, Integer, 8)                    \
+  macro(Real, 4, Real, 4)                       \
+  macro(Real, 4, Real, 8)                       \
+  macro(Real, 8, Real, 4)                       \
+  macro(Real, 8, Real, 8)                       \
+  macro(Real, 4, Complex, 4)                    \
+  macro(Real, 4, Complex, 8)                    \
+  macro(Real, 8, Complex, 4)                    \
+  macro(Real, 8, Complex, 8)                    \
+  macro(Complex, 4, Integer, 1)                 \
+  macro(Complex, 4, Integer, 2)                 \
+  macro(Complex, 4, Integer, 4)                 \
+  macro(Complex, 4, Integer, 8)                 \
+  macro(Complex, 8, Integer, 1)                 \
+  macro(Complex, 8, Integer, 2)                 \
+  macro(Complex, 8, Integer, 4)                 \
+  macro(Complex, 8, Integer, 8)                 \
+  macro(Complex, 4, Real, 4)                    \
+  macro(Complex, 4, Real, 8)                    \
+  macro(Complex, 8, Real, 4)                    \
+  macro(Complex, 8, Real, 8)                    \
+  macro(Complex, 4, Complex, 4)                 \
+  macro(Complex, 4, Complex, 8)                 \
+  macro(Complex, 8, Complex, 4)                 \
+  macro(Complex, 8, Complex, 8)                 \
+
+FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
+#undef FOREACH_MATMUL_TYPE_PAIR
+#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro)      \
+  macro(Integer, 16, Integer, 1)                        \
+  macro(Integer, 16, Integer, 2)                        \
+  macro(Integer, 16, Integer, 4)                        \
+  macro(Integer, 16, Integer, 8)                        \
+  macro(Integer, 16, Integer, 16)                       \
+  macro(Integer, 16, Real, 4)                           \
+  macro(Integer, 16, Real, 8)                           \
+  macro(Integer, 16, Complex, 4)                        \
+  macro(Integer, 16, Complex, 8)                        \
+  macro(Real, 4, Integer, 16)                           \
+  macro(Real, 8, Integer, 16)                           \
+  macro(Complex, 4, Integer, 16)                        \
+  macro(Complex, 8, Integer, 16)                        \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)
+#undef FOREACH_MATMUL_TYPE_PAIR_WITH_INT16
+#if LDBL_MANT_DIG == 64
+MATMUL_INSTANCE(Integer, 16, Real, 10)
+MATMUL_INSTANCE(Integer, 16, Complex, 10)
+MATMUL_INSTANCE(Real, 10, Integer, 16)
+MATMUL_INSTANCE(Complex, 10, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+MATMUL_INSTANCE(Integer, 16, Real, 16)
+MATMUL_INSTANCE(Integer, 16, Complex, 16)
+MATMUL_INSTANCE(Real, 16, Integer, 16)
+MATMUL_INSTANCE(Complex, 16, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
+#endif
+#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+
+#if LDBL_MANT_DIG == 64
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro)         \
+  macro(Integer, 1, Real, 10)                               \
+  macro(Integer, 1, Complex, 10)                            \
+  macro(Integer, 2, Real, 10)                               \
+  macro(Integer, 2, Complex, 10)                            \
+  macro(Integer, 4, Real, 10)                               \
+  macro(Integer, 4, Complex, 10)                            \
+  macro(Integer, 8, Real, 10)                               \
+  macro(Integer, 8, Complex, 10)                            \
+  macro(Real, 4, Real, 10)                                  \
+  macro(Real, 4, Complex, 10)                               \
+  macro(Real, 8, Real, 10)                                  \
+  macro(Real, 8, Complex, 10)                               \
+  macro(Real, 10, Integer, 1)                               \
+  macro(Real, 10, Integer, 2)                               \
+  macro(Real, 10, Integer, 4)                               \
+  macro(Real, 10, Integer, 8)                               \
+  macro(Real, 10, Real, 4)                                  \
+  macro(Real, 10, Real, 8)                                  \
+  macro(Real, 10, Real, 10)                                 \
+  macro(Real, 10, Complex, 4)                               \
+  macro(Real, 10, Complex, 8)                               \
+  macro(Real, 10, Complex, 10)                              \
+  macro(Complex, 4, Real, 10)                               \
+  macro(Complex, 4, Complex, 10)                            \
+  macro(Complex, 8, Real, 10)                               \
+  macro(Complex, 8, Complex, 10)                            \
+  macro(Complex, 10, Integer, 1)                            \
+  macro(Complex, 10, Integer, 2)                            \
+  macro(Complex, 10, Integer, 4)                            \
+  macro(Complex, 10, Integer, 8)                            \
+  macro(Complex, 10, Real, 4)                               \
+  macro(Complex, 10, Real, 8)                               \
+  macro(Complex, 10, Real, 10)                              \
+  macro(Complex, 10, Complex, 4)                            \
+  macro(Complex, 10, Complex, 8)                            \
+  macro(Complex, 10, Complex, 10)                           \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)
+#undef FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10
+#if HAS_FLOAT128
+MATMUL_INSTANCE(Real, 10, Real, 16)
+MATMUL_INSTANCE(Real, 10, Complex, 16)
+MATMUL_INSTANCE(Real, 16, Real, 10)
+MATMUL_INSTANCE(Real, 16, Complex, 10)
+MATMUL_INSTANCE(Complex, 10, Real, 16)
+MATMUL_INSTANCE(Complex, 10, Complex, 16)
+MATMUL_INSTANCE(Complex, 16, Real, 10)
+MATMUL_INSTANCE(Complex, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Real, 10, Real, 16)
+MATMUL_DIRECT_INSTANCE(Real, 10, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Real, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Real, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Real, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
+#endif
+#endif // LDBL_MANT_DIG == 64
+
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro)         \
+  macro(Integer, 1, Real, 16)                               \
+  macro(Integer, 1, Complex, 16)                            \
+  macro(Integer, 2, Real, 16)                               \
+  macro(Integer, 2, Complex, 16)                            \
+  macro(Integer, 4, Real, 16)                               \
+  macro(Integer, 4, Complex, 16)                            \
+  macro(Integer, 8, Real, 16)                               \
+  macro(Integer, 8, Complex, 16)                            \
+  macro(Real, 4, Real, 16)                                  \
+  macro(Real, 4, Complex, 16)                               \
+  macro(Real, 8, Real, 16)                                  \
+  macro(Real, 8, Complex, 16)                               \
+  macro(Real, 16, Integer, 1)                               \
+  macro(Real, 16, Integer, 2)                               \
+  macro(Real, 16, Integer, 4)                               \
+  macro(Real, 16, Integer, 8)                               \
+  macro(Real, 16, Real, 4)                                  \
+  macro(Real, 16, Real, 8)                                  \
+  macro(Real, 16, Real, 16)                                 \
+  macro(Real, 16, Complex, 4)                               \
+  macro(Real, 16, Complex, 8)                               \
+  macro(Real, 16, Complex, 16)                              \
+  macro(Complex, 4, Real, 16)                               \
+  macro(Complex, 4, Complex, 16)                            \
+  macro(Complex, 8, Real, 16)                               \
+  macro(Complex, 8, Complex, 16)                            \
+  macro(Complex, 16, Integer, 1)                            \
+  macro(Complex, 16, Integer, 2)                            \
+  macro(Complex, 16, Integer, 4)                            \
+  macro(Complex, 16, Integer, 8)                            \
+  macro(Complex, 16, Real, 4)                               \
+  macro(Complex, 16, Real, 8)                               \
+  macro(Complex, 16, Real, 16)                              \
+  macro(Complex, 16, Complex, 4)                            \
+  macro(Complex, 16, Complex, 8)                            \
+  macro(Complex, 16, Complex, 16)                           \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
+#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#undef FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16
+#define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
+  macro(Logical, 1, Logical, 1)                 \
+  macro(Logical, 1, Logical, 2)                 \
+  macro(Logical, 1, Logical, 4)                 \
+  macro(Logical, 1, Logical, 8)                 \
+  macro(Logical, 2, Logical, 1)                 \
+  macro(Logical, 2, Logical, 2)                 \
+  macro(Logical, 2, Logical, 4)                 \
+  macro(Logical, 2, Logical, 8)                 \
+  macro(Logical, 4, Logical, 1)                 \
+  macro(Logical, 4, Logical, 2)                 \
+  macro(Logical, 4, Logical, 4)                 \
+  macro(Logical, 4, Logical, 8)                 \
+  macro(Logical, 8, Logical, 1)                 \
+  macro(Logical, 8, Logical, 2)                 \
+  macro(Logical, 8, Logical, 4)                 \
+  macro(Logical, 8, Logical, 8)                 \
+
+FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_INSTANCE)
+FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
+#undef FOREACH_MATMUL_LOGICAL_TYPE_PAIR
+#undef MATMUL_INSTANCE
+#undef MATMUL_DIRECT_INSTANCE
+
+// clang-format on
diff --git a/flang/include/flang/Runtime/matmul-transpose.h b/flang/include/flang/Runtime/matmul-transpose.h
index 5eb5896972e0f..d0a5005a1a8bd 100644
--- a/flang/include/flang/Runtime/matmul-transpose.h
+++ b/flang/include/flang/Runtime/matmul-transpose.h
@@ -10,6 +10,8 @@
 
 #ifndef FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
 #define FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
+#include "flang/Common/float128.h"
+#include "flang/Common/uint128.h"
 #include "flang/Runtime/entry-names.h"
 namespace Fortran::runtime {
 class Descriptor;
@@ -25,6 +27,21 @@ void RTDECL(MatmulTranspose)(Descriptor &, const Descriptor &,
 // and have a valid base address.
 void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
     const Descriptor &, const char *sourceFile = nullptr, int line = 0);
+
+// MATMUL(TRANSPOSE()) versions specialized by the categories and KINDs of the
+// operand types, both encoded in the entry name (they must match the argument
+// descriptors). Shape information is taken from the descriptors.
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDECL(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+      int line);
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDECL(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
+      Descriptor & result, const Descriptor &x, const Descriptor &y, \
+      const char *sourceFile, int line);
+// NOTE(review): unlike MatmulTransposeDirect, these take a non-const result.
+#include "matmul-instances.inc"
+
 } // extern "C"
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
diff --git a/flang/include/flang/Runtime/matmul.h b/flang/include/flang/Runtime/matmul.h
index 40581d44de9e2..1a5e39eb8813f 100644
--- a/flang/include/flang/Runtime/matmul.h
+++ b/flang/include/flang/Runtime/matmul.h
@@ -10,6 +10,8 @@
 
 #ifndef FORTRAN_RUNTIME_MATMUL_H_
 #define FORTRAN_RUNTIME_MATMUL_H_
+#include "flang/Common/float128.h"
+#include "flang/Common/uint128.h"
 #include "flang/Runtime/entry-names.h"
 namespace Fortran::runtime {
 class Descriptor;
@@ -24,6 +26,21 @@ void RTDECL(Matmul)(Descriptor &, const Descriptor &, const Descriptor &,
 // and have a valid base address.
 void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
     const Descriptor &, const char *sourceFile = nullptr, int line = 0);
+
+// MATMUL versions specialized by the categories and KINDs of the operand
+// types, both encoded in the entry name (they must match the argument
+// descriptors). Shape information is taken from the descriptors.
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDECL(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+      int line);
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDECL(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+      int line);
+// NOTE(review): unlike MatmulDirect, these take a non-const result.
+#include "matmul-instances.inc"
+
 } // extern "C"
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_MATMUL_H_
diff --git a/flang/runtime/matmul-transpose.cpp b/flang/runtime/matmul-transpose.cpp
index a12d188266f7c..1c998fa8cf6c1 100644
--- a/flang/runtime/matmul-transpose.cpp
+++ b/flang/runtime/matmul-transpose.cpp
@@ -384,6 +384,41 @@ template <bool IS_ALLOCATING> struct MatmulTranspose {
         x, y, terminator, yCatKind->first, yCatKind->second);
   }
 };
+
+// Implementation of the MatmulTranspose[Direct]<XCAT><XKIND><YCAT><YKIND>
+// entry points. The operand categories and KINDs are template arguments, so
+// each entry point only instantiates DoMatmulTranspose for its own type
+// pair. The descriptors must agree with the template arguments, because
+// the operand data is accessed as CppTypeFor<XCAT, XKIND> and
+// CppTypeFor<YCAT, YKIND>.
+template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
+    int YKIND>
+struct MatmulTransposeHelper {
+  // The non-allocating ("Direct") entries receive a const result descriptor.
+  using ResultDescriptor =
+      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
+      const Descriptor &y, const char *sourceFile, int line) const {
+    Terminator terminator{sourceFile, line};
+    auto xCatKind{x.type().GetCategoryAndKind()};
+    auto yCatKind{y.type().GetCategoryAndKind()};
+    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+    RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
+    RUNTIME_CHECK(terminator, yCatKind->first == YCAT);
+    // The KINDs are fixed by the entry name; a mismatch with the descriptors
+    // would make the data be reinterpreted as the wrong C++ type.
+    RUNTIME_CHECK(terminator, xCatKind->second == XKIND);
+    RUNTIME_CHECK(terminator, yCatKind->second == YKIND);
+    if constexpr (constexpr auto resultType{
+                      GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+      return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
+          resultType->second, CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
+          result, x, y, terminator);
+    }
+    terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
+        static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
+  }
+};
 } // namespace
 
 namespace Fortran::runtime {
@@ -399,6 +423,24 @@ void RTDEF(MatmulTransposeDirect)(const Descriptor &result, const Descriptor &x,
   MatmulTranspose<false>{}(result, x, y, sourceFile, line);
 }
 
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDEF(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+      int line) { \
+    MatmulTransposeHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
+        YKIND>{}(result, x, y, sourceFile, line); \
+  }
+// Non-allocating (IS_ALLOCATING == false) counterparts of the entries above.
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDEF(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
+      Descriptor & result, const Descriptor &x, const Descriptor &y, \
+      const char *sourceFile, int line) { \
+    MatmulTransposeHelper<false, TypeCategory::XCAT, XKIND, \
+        TypeCategory::YCAT, YKIND>{}(result, x, y, sourceFile, line); \
+  }
+// Define both entry-point flavors for every pair in matmul-instances.inc.
+#include "flang/Runtime/matmul-instances.inc"
+
 RT_EXT_API_GROUP_END
 } // extern "C"
 } // namespace Fortran::runtime
diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
index 8f9b50a549e1f..504d1aa4dc4a4 100644
--- a/flang/runtime/matmul.cpp
+++ b/flang/runtime/matmul.cpp
@@ -28,7 +28,8 @@
 #include "flang/Runtime/descriptor.h"
 #include <cstring>
 
-namespace Fortran::runtime {
+namespace {
+using namespace Fortran::runtime;
 
 // Suppress the warnings about calling __host__-only std::complex operators,
 // defined in C++ STD header files, from __device__ code.
@@ -455,7 +456,8 @@ template <bool IS_ALLOCATING> struct Matmul {
           Terminator &terminator) const {
         if constexpr (constexpr auto resultType{
                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
-          if constexpr (common::IsNumericTypeCategory(resultType->first) ||
+          if constexpr (Fortran::common::IsNumericTypeCategory(
+                            resultType->first) ||
               resultType->first == TypeCategory::Logical) {
             return DoMatmul<IS_ALLOCATING, resultType->first,
                 resultType->second, CppTypeFor<XCAT, XKIND>,
@@ -483,6 +485,42 @@ template <bool IS_ALLOCATING> struct Matmul {
   }
 };
 
+// Implementation of the Matmul[Direct]<XCAT><XKIND><YCAT><YKIND> entry
+// points. The operand categories and KINDs are template arguments, so each
+// entry point only instantiates DoMatmul for its own type pair. The
+// descriptors must agree with the template arguments, because the operand
+// data is accessed as CppTypeFor<XCAT, XKIND> and CppTypeFor<YCAT, YKIND>.
+template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
+    int YKIND>
+struct MatmulHelper {
+  // The non-allocating ("Direct") entries receive a const result descriptor.
+  using ResultDescriptor =
+      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
+      const Descriptor &y, const char *sourceFile, int line) const {
+    Terminator terminator{sourceFile, line};
+    auto xCatKind{x.type().GetCategoryAndKind()};
+    auto yCatKind{y.type().GetCategoryAndKind()};
+    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+    RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
+    RUNTIME_CHECK(terminator, yCatKind->first == YCAT);
+    // The KINDs are fixed by the entry name; a mismatch with the descriptors
+    // would make the data be reinterpreted as the wrong C++ type.
+    RUNTIME_CHECK(terminator, xCatKind->second == XKIND);
+    RUNTIME_CHECK(terminator, yCatKind->second == YKIND);
+    if constexpr (constexpr auto resultType{
+                      GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+      return DoMatmul<IS_ALLOCATING, resultType->first, resultType->second,
+          CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
+          result, x, y, terminator);
+    }
+    terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
+        static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
+  }
+};
+} // namespace
+
+namespace Fortran::runtime {
 extern "C" {
 RT_EXT_API_GROUP_BEGIN
 
@@ -495,6 +523,24 @@ void RTDEF(MatmulDirect)(const Descriptor &result, const Descriptor &x,
   Matmul<false>{}(result, x, y, sourceFile, line);
 }
 
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDEF(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+      int line) { \
+    MatmulHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
+        YKIND>{}(result, x, y, sourceFile, line); \
+  }
+// Non-allocating (IS_ALLOCATING == false) counterparts of the entries above.
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDEF(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+      int line) { \
+    MatmulHelper<false, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
+        YKIND>{}(result, x, y, sourceFile, line); \
+  }
+// Define both entry-point flavors for every pair in matmul-instances.inc.
+#include "flang/Runtime/matmul-instances.inc"
+
 RT_EXT_API_GROUP_END
 } // extern "C"
 } // namespace Fortran::runtime
diff --git a/flang/unittests/Runtime/Matmul.cpp b/flang/unittests/Runtime/Matmul.cpp
index 1d6c5ccc609b4..226dbc5ae9eeb 100644
--- a/flang/unittests/Runtime/Matmul.cpp
+++ b/flang/unittests/Runtime/Matmul.cpp
@@ -63,6 +63,29 @@ TEST(Matmul, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
   result.Destroy();
 
+  RTNAME(MatmulInteger4Integer2)(result, *x, *y, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+
+  std::memset(
+      result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
+  result.GetDimension(0).SetLowerBound(0);
+  result.GetDimension(1).SetLowerBound(2);
+  RTNAME(MatmulDirectInteger4Integer2)(result, *x, *y, __FILE__, __LINE__);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+  result.Destroy();
+
   RTNAME(Matmul)(result, *v, *x, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 1);
   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -73,6 +96,16 @@ TEST(Matmul, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
   result.Destroy();
 
+  RTNAME(MatmulInteger8Integer4)(result, *v, *x, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
+  result.Destroy();
+
   RTNAME(Matmul)(result, *y, *v, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 1);
   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -83,6 +116,16 @@ TEST(Matmul, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
   result.Destroy();
 
+  RTNAME(MatmulInteger2Integer8)(result, *y, *v, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+  result.Destroy();
+
   // Test non-contiguous sections.
   static constexpr int sectionRank{2};
   StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
@@ -129,6 +172,19 @@ TEST(Matmul, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
   result.Destroy();
 
+  RTNAME(MatmulInteger4Integer2)(result, sectionX2, *y, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+  result.Destroy();
+
   RTNAME(Matmul)(result, *x, sectionY2, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 2);
   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -142,6 +198,19 @@ TEST(Matmul, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
   result.Destroy();
 
+  RTNAME(MatmulInteger4Integer2)(result, *x, sectionY2, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+  result.Destroy();
+
   RTNAME(Matmul)(result, sectionX2, sectionY2, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 2);
   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -155,6 +224,20 @@ TEST(Matmul, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
   result.Destroy();
 
+  RTNAME(MatmulInteger4Integer2)
+  (result, sectionX2, sectionY2, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+  result.Destroy();
+
   RTNAME(Matmul)(result, *v, sectionX2, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 1);
   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -165,6 +248,16 @@ TEST(Matmul, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
   result.Destroy();
 
+  RTNAME(MatmulInteger8Integer4)(result, *v, sectionX2, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
+  result.Destroy();
+
   RTNAME(Matmul)(result, sectionY2, *v, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 1);
   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -175,6 +268,16 @@ TEST(Matmul, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
   result.Destroy();
 
+  RTNAME(MatmulInteger2Integer8)(result, sectionY2, *v, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+  result.Destroy();
+
   // X F F T  Y F T
   //   F T T    F T
   //            F F
@@ -197,4 +300,22 @@ TEST(Matmul, Basic) {
       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
   EXPECT_TRUE(
       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+  result.Destroy();
+
+  RTNAME(MatmulLogical1Logical2)(result, *xLog, *yLog, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
+  EXPECT_TRUE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+  result.Destroy();
 }
diff --git a/flang/unittests/Runtime/MatmulTranspose.cpp b/flang/unittests/Runtime/MatmulTranspose.cpp
index fe946f6d5a201..391c2e1b144ea 100644
--- a/flang/unittests/Runtime/MatmulTranspose.cpp
+++ b/flang/unittests/Runtime/MatmulTranspose.cpp
@@ -69,6 +69,30 @@ TEST(MatmulTranspose, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
   result.Destroy();
 
+  RTNAME(MatmulTransposeInteger4Integer2)(result, *x, *y, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+
+  std::memset(
+      result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
+  result.GetDimension(0).SetLowerBound(0);
+  result.GetDimension(1).SetLowerBound(2);
+  RTNAME(MatmulTransposeDirectInteger4Integer2)
+  (result, *x, *y, __FILE__, __LINE__);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+  result.Destroy();
+
   RTNAME(MatmulTranspose)(result, *z, *v, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 1);
   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -79,6 +103,16 @@ TEST(MatmulTranspose, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
   result.Destroy();
 
+  RTNAME(MatmulTransposeInteger2Integer8)(result, *z, *v, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+  result.Destroy();
+
   RTNAME(MatmulTranspose)(result, *m, *z, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 2);
   ASSERT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -100,6 +134,27 @@ TEST(MatmulTranspose, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(11), 19);
   result.Destroy();
 
+  RTNAME(MatmulTransposeInteger2Integer2)(result, *m, *z, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  ASSERT_EQ(result.GetDimension(0).LowerBound(), 1);
+  ASSERT_EQ(result.GetDimension(0).UpperBound(), 4);
+  ASSERT_EQ(result.GetDimension(1).LowerBound(), 1);
+  ASSERT_EQ(result.GetDimension(1).UpperBound(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 2}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(0), 0);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(1), 9);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(2), 6);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(3), 15);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(4), 0);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(5), 10);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(6), 7);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(7), 17);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(8), 0);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(9), 11);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(10), 8);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(11), 19);
+  result.Destroy();
+
   // Test non-contiguous sections.
   static constexpr int sectionRank{2};
   StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
@@ -162,6 +217,20 @@ TEST(MatmulTranspose, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
   result.Destroy();
 
+  RTNAME(MatmulTransposeInteger4Integer2)
+  (result, sectionX2, *y, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+  result.Destroy();
+
   RTNAME(MatmulTranspose)(result, *x, sectionY2, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 2);
   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -175,6 +244,20 @@ TEST(MatmulTranspose, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
   result.Destroy();
 
+  RTNAME(MatmulTransposeInteger4Integer2)
+  (result, *x, sectionY2, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+  result.Destroy();
+
   RTNAME(MatmulTranspose)(result, sectionX2, sectionY2, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 2);
   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -188,6 +271,20 @@ TEST(MatmulTranspose, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
   result.Destroy();
 
+  RTNAME(MatmulTransposeInteger4Integer2)
+  (result, sectionX2, sectionY2, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+  result.Destroy();
+
   RTNAME(MatmulTranspose)(result, sectionZ2, *v, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 1);
   EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -198,6 +295,17 @@ TEST(MatmulTranspose, Basic) {
   EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
   result.Destroy();
 
+  RTNAME(MatmulTransposeInteger2Integer8)
+  (result, sectionZ2, *v, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+  EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+  result.Destroy();
+
   // X F F    Y F T    V T F T
   //   T F      F T
   //   T T      F F
@@ -222,6 +330,25 @@ TEST(MatmulTranspose, Basic) {
       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
   EXPECT_FALSE(
       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+  result.Destroy();
+
+  RTNAME(MatmulTransposeLogical1Logical2)
+  (result, *xLog, *yLog, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 2);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+  EXPECT_TRUE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+  result.Destroy();
 
   RTNAME(MatmulTranspose)(result, *yLog, *vLog, __FILE__, __LINE__);
   ASSERT_EQ(result.rank(), 1);
@@ -232,4 +359,17 @@ TEST(MatmulTranspose, Basic) {
       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
   EXPECT_TRUE(
       static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+  result.Destroy();
+
+  RTNAME(MatmulTransposeLogical2Logical1)
+  (result, *yLog, *vLog, __FILE__, __LINE__);
+  ASSERT_EQ(result.rank(), 1);
+  EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+  EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+  ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
+  EXPECT_FALSE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
+  EXPECT_TRUE(
+      static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+  result.Destroy();
 }



More information about the flang-commits mailing list