[flang-commits] [flang] dd22085 - [flang][runtime] Split MATMUL[_TRANSPOSE] into separate entries. (#97406)
via flang-commits
flang-commits at lists.llvm.org
Tue Jul 2 21:30:40 PDT 2024
Author: Slava Zakharin
Date: 2024-07-02T21:30:37-07:00
New Revision: dd220853081400db6b4f85027030645115229ba0
URL: https://github.com/llvm/llvm-project/commit/dd220853081400db6b4f85027030645115229ba0
DIFF: https://github.com/llvm/llvm-project/commit/dd220853081400db6b4f85027030645115229ba0.diff
LOG: [flang][runtime] Split MATMUL[_TRANSPOSE] into separate entries. (#97406)
Device compilation is much faster for separate MATMUL[_TRANSPOSE]
entries than for a single one that covers all data types.
The lowering changes and the removal of the generic entries will follow.
Added:
flang/include/flang/Runtime/matmul-instances.inc
Modified:
flang/include/flang/Runtime/matmul-transpose.h
flang/include/flang/Runtime/matmul.h
flang/runtime/matmul-transpose.cpp
flang/runtime/matmul.cpp
flang/unittests/Runtime/Matmul.cpp
flang/unittests/Runtime/MatmulTranspose.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Runtime/matmul-instances.inc b/flang/include/flang/Runtime/matmul-instances.inc
new file mode 100644
index 0000000000000..970b03339cd5e
--- /dev/null
+++ b/flang/include/flang/Runtime/matmul-instances.inc
@@ -0,0 +1,261 @@
+//===-- include/flang/Runtime/matmul-instances.inc --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Helper macros to instantiate MATMUL/MATMUL_TRANSPOSE definitions
+// for different data types of the input arguments.
+//===----------------------------------------------------------------------===//
+
+#ifndef MATMUL_INSTANCE
+#error "Define MATMUL_INSTANCE before including this file"
+#endif
+
+#ifndef MATMUL_DIRECT_INSTANCE
+#error "Define MATMUL_DIRECT_INSTANCE before including this file"
+#endif
+
+// clang-format off
+
+#define FOREACH_MATMUL_TYPE_PAIR(macro) \
+ macro(Integer, 1, Integer, 1) \
+ macro(Integer, 1, Integer, 2) \
+ macro(Integer, 1, Integer, 4) \
+ macro(Integer, 1, Integer, 8) \
+ macro(Integer, 2, Integer, 1) \
+ macro(Integer, 2, Integer, 2) \
+ macro(Integer, 2, Integer, 4) \
+ macro(Integer, 2, Integer, 8) \
+ macro(Integer, 4, Integer, 1) \
+ macro(Integer, 4, Integer, 2) \
+ macro(Integer, 4, Integer, 4) \
+ macro(Integer, 4, Integer, 8) \
+ macro(Integer, 8, Integer, 1) \
+ macro(Integer, 8, Integer, 2) \
+ macro(Integer, 8, Integer, 4) \
+ macro(Integer, 8, Integer, 8) \
+ macro(Integer, 1, Real, 4) \
+ macro(Integer, 1, Real, 8) \
+ macro(Integer, 2, Real, 4) \
+ macro(Integer, 2, Real, 8) \
+ macro(Integer, 4, Real, 4) \
+ macro(Integer, 4, Real, 8) \
+ macro(Integer, 8, Real, 4) \
+ macro(Integer, 8, Real, 8) \
+ macro(Integer, 1, Complex, 4) \
+ macro(Integer, 1, Complex, 8) \
+ macro(Integer, 2, Complex, 4) \
+ macro(Integer, 2, Complex, 8) \
+ macro(Integer, 4, Complex, 4) \
+ macro(Integer, 4, Complex, 8) \
+ macro(Integer, 8, Complex, 4) \
+ macro(Integer, 8, Complex, 8) \
+ macro(Real, 4, Integer, 1) \
+ macro(Real, 4, Integer, 2) \
+ macro(Real, 4, Integer, 4) \
+ macro(Real, 4, Integer, 8) \
+ macro(Real, 8, Integer, 1) \
+ macro(Real, 8, Integer, 2) \
+ macro(Real, 8, Integer, 4) \
+ macro(Real, 8, Integer, 8) \
+ macro(Real, 4, Real, 4) \
+ macro(Real, 4, Real, 8) \
+ macro(Real, 8, Real, 4) \
+ macro(Real, 8, Real, 8) \
+ macro(Real, 4, Complex, 4) \
+ macro(Real, 4, Complex, 8) \
+ macro(Real, 8, Complex, 4) \
+ macro(Real, 8, Complex, 8) \
+ macro(Complex, 4, Integer, 1) \
+ macro(Complex, 4, Integer, 2) \
+ macro(Complex, 4, Integer, 4) \
+ macro(Complex, 4, Integer, 8) \
+ macro(Complex, 8, Integer, 1) \
+ macro(Complex, 8, Integer, 2) \
+ macro(Complex, 8, Integer, 4) \
+ macro(Complex, 8, Integer, 8) \
+ macro(Complex, 4, Real, 4) \
+ macro(Complex, 4, Real, 8) \
+ macro(Complex, 8, Real, 4) \
+ macro(Complex, 8, Real, 8) \
+ macro(Complex, 4, Complex, 4) \
+ macro(Complex, 4, Complex, 8) \
+ macro(Complex, 8, Complex, 4) \
+ macro(Complex, 8, Complex, 8) \
+
+FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
+
+#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro) \
+ macro(Integer, 16, Integer, 1) \
+ macro(Integer, 16, Integer, 2) \
+ macro(Integer, 16, Integer, 4) \
+ macro(Integer, 16, Integer, 8) \
+ macro(Integer, 16, Integer, 16) \
+ macro(Integer, 16, Real, 4) \
+ macro(Integer, 16, Real, 8) \
+ macro(Integer, 16, Complex, 4) \
+ macro(Integer, 16, Complex, 8) \
+ macro(Real, 4, Integer, 16) \
+ macro(Real, 8, Integer, 16) \
+ macro(Complex, 4, Integer, 16) \
+ macro(Complex, 8, Integer, 16) \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)
+
+#if LDBL_MANT_DIG == 64
+MATMUL_INSTANCE(Integer, 16, Real, 10)
+MATMUL_INSTANCE(Integer, 16, Complex, 10)
+MATMUL_INSTANCE(Real, 10, Integer, 16)
+MATMUL_INSTANCE(Complex, 10, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+MATMUL_INSTANCE(Integer, 16, Real, 16)
+MATMUL_INSTANCE(Integer, 16, Complex, 16)
+MATMUL_INSTANCE(Real, 16, Integer, 16)
+MATMUL_INSTANCE(Complex, 16, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
+#endif
+#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+
+#if LDBL_MANT_DIG == 64
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro) \
+ macro(Integer, 1, Real, 10) \
+ macro(Integer, 1, Complex, 10) \
+ macro(Integer, 2, Real, 10) \
+ macro(Integer, 2, Complex, 10) \
+ macro(Integer, 4, Real, 10) \
+ macro(Integer, 4, Complex, 10) \
+ macro(Integer, 8, Real, 10) \
+ macro(Integer, 8, Complex, 10) \
+ macro(Real, 4, Real, 10) \
+ macro(Real, 4, Complex, 10) \
+ macro(Real, 8, Real, 10) \
+ macro(Real, 8, Complex, 10) \
+ macro(Real, 10, Integer, 1) \
+ macro(Real, 10, Integer, 2) \
+ macro(Real, 10, Integer, 4) \
+ macro(Real, 10, Integer, 8) \
+ macro(Real, 10, Real, 4) \
+ macro(Real, 10, Real, 8) \
+ macro(Real, 10, Real, 10) \
+ macro(Real, 10, Complex, 4) \
+ macro(Real, 10, Complex, 8) \
+ macro(Real, 10, Complex, 10) \
+ macro(Complex, 4, Real, 10) \
+ macro(Complex, 4, Complex, 10) \
+ macro(Complex, 8, Real, 10) \
+ macro(Complex, 8, Complex, 10) \
+ macro(Complex, 10, Integer, 1) \
+ macro(Complex, 10, Integer, 2) \
+ macro(Complex, 10, Integer, 4) \
+ macro(Complex, 10, Integer, 8) \
+ macro(Complex, 10, Real, 4) \
+ macro(Complex, 10, Real, 8) \
+ macro(Complex, 10, Real, 10) \
+ macro(Complex, 10, Complex, 4) \
+ macro(Complex, 10, Complex, 8) \
+ macro(Complex, 10, Complex, 10) \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)
+
+#if HAS_FLOAT128
+MATMUL_INSTANCE(Real, 10, Real, 16)
+MATMUL_INSTANCE(Real, 10, Complex, 16)
+MATMUL_INSTANCE(Real, 16, Real, 10)
+MATMUL_INSTANCE(Real, 16, Complex, 10)
+MATMUL_INSTANCE(Complex, 10, Real, 16)
+MATMUL_INSTANCE(Complex, 10, Complex, 16)
+MATMUL_INSTANCE(Complex, 16, Real, 10)
+MATMUL_INSTANCE(Complex, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Real, 10, Real, 16)
+MATMUL_DIRECT_INSTANCE(Real, 10, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Real, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Real, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Real, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
+#endif
+#endif // LDBL_MANT_DIG == 64
+
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro) \
+ macro(Integer, 1, Real, 16) \
+ macro(Integer, 1, Complex, 16) \
+ macro(Integer, 2, Real, 16) \
+ macro(Integer, 2, Complex, 16) \
+ macro(Integer, 4, Real, 16) \
+ macro(Integer, 4, Complex, 16) \
+ macro(Integer, 8, Real, 16) \
+ macro(Integer, 8, Complex, 16) \
+ macro(Real, 4, Real, 16) \
+ macro(Real, 4, Complex, 16) \
+ macro(Real, 8, Real, 16) \
+ macro(Real, 8, Complex, 16) \
+ macro(Real, 16, Integer, 1) \
+ macro(Real, 16, Integer, 2) \
+ macro(Real, 16, Integer, 4) \
+ macro(Real, 16, Integer, 8) \
+ macro(Real, 16, Real, 4) \
+ macro(Real, 16, Real, 8) \
+ macro(Real, 16, Real, 16) \
+ macro(Real, 16, Complex, 4) \
+ macro(Real, 16, Complex, 8) \
+ macro(Real, 16, Complex, 16) \
+ macro(Complex, 4, Real, 16) \
+ macro(Complex, 4, Complex, 16) \
+ macro(Complex, 8, Real, 16) \
+ macro(Complex, 8, Complex, 16) \
+ macro(Complex, 16, Integer, 1) \
+ macro(Complex, 16, Integer, 2) \
+ macro(Complex, 16, Integer, 4) \
+ macro(Complex, 16, Integer, 8) \
+ macro(Complex, 16, Real, 4) \
+ macro(Complex, 16, Real, 8) \
+ macro(Complex, 16, Real, 16) \
+ macro(Complex, 16, Complex, 4) \
+ macro(Complex, 16, Complex, 8) \
+ macro(Complex, 16, Complex, 16) \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
+#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128
+
+#define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
+ macro(Logical, 1, Logical, 1) \
+ macro(Logical, 1, Logical, 2) \
+ macro(Logical, 1, Logical, 4) \
+ macro(Logical, 1, Logical, 8) \
+ macro(Logical, 2, Logical, 1) \
+ macro(Logical, 2, Logical, 2) \
+ macro(Logical, 2, Logical, 4) \
+ macro(Logical, 2, Logical, 8) \
+ macro(Logical, 4, Logical, 1) \
+ macro(Logical, 4, Logical, 2) \
+ macro(Logical, 4, Logical, 4) \
+ macro(Logical, 4, Logical, 8) \
+ macro(Logical, 8, Logical, 1) \
+ macro(Logical, 8, Logical, 2) \
+ macro(Logical, 8, Logical, 4) \
+ macro(Logical, 8, Logical, 8) \
+
+FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_INSTANCE)
+FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
+
+#undef MATMUL_INSTANCE
+#undef MATMUL_DIRECT_INSTANCE
+
+// clang-format on
diff --git a/flang/include/flang/Runtime/matmul-transpose.h b/flang/include/flang/Runtime/matmul-transpose.h
index 5eb5896972e0f..d0a5005a1a8bd 100644
--- a/flang/include/flang/Runtime/matmul-transpose.h
+++ b/flang/include/flang/Runtime/matmul-transpose.h
@@ -10,6 +10,8 @@
#ifndef FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
#define FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
+#include "flang/Common/float128.h"
+#include "flang/Common/uint128.h"
#include "flang/Runtime/entry-names.h"
namespace Fortran::runtime {
class Descriptor;
@@ -25,6 +27,21 @@ void RTDECL(MatmulTranspose)(Descriptor &, const Descriptor &,
// and have a valid base address.
void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
const Descriptor &, const char *sourceFile = nullptr, int line = 0);
+
+// MATMUL(TRANSPOSE()) versions specialized by the categories of the operand
+// types. The KIND and shape information is taken from the argument's
+// descriptors.
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDECL(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+ const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+ int line);
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDECL(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
+ Descriptor & result, const Descriptor &x, const Descriptor &y, \
+ const char *sourceFile, int line);
+
+#include "matmul-instances.inc"
+
} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
diff --git a/flang/include/flang/Runtime/matmul.h b/flang/include/flang/Runtime/matmul.h
index 40581d44de9e2..1a5e39eb8813f 100644
--- a/flang/include/flang/Runtime/matmul.h
+++ b/flang/include/flang/Runtime/matmul.h
@@ -10,6 +10,8 @@
#ifndef FORTRAN_RUNTIME_MATMUL_H_
#define FORTRAN_RUNTIME_MATMUL_H_
+#include "flang/Common/float128.h"
+#include "flang/Common/uint128.h"
#include "flang/Runtime/entry-names.h"
namespace Fortran::runtime {
class Descriptor;
@@ -24,6 +26,21 @@ void RTDECL(Matmul)(Descriptor &, const Descriptor &, const Descriptor &,
// and have a valid base address.
void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
const Descriptor &, const char *sourceFile = nullptr, int line = 0);
+
+// MATMUL versions specialized by the categories of the operand types.
+// The KIND and shape information is taken from the argument's
+// descriptors.
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDECL(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+ const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+ int line);
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDECL(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+ const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+ int line);
+
+#include "matmul-instances.inc"
+
} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_MATMUL_H_
diff --git a/flang/runtime/matmul-transpose.cpp b/flang/runtime/matmul-transpose.cpp
index a12d188266f7c..1c998fa8cf6c1 100644
--- a/flang/runtime/matmul-transpose.cpp
+++ b/flang/runtime/matmul-transpose.cpp
@@ -384,6 +384,30 @@ template <bool IS_ALLOCATING> struct MatmulTranspose {
x, y, terminator, yCatKind->first, yCatKind->second);
}
};
+
+template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
+ int YKIND>
+struct MatmulTransposeHelper {
+ using ResultDescriptor =
+ std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+ RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
+ const Descriptor &y, const char *sourceFile, int line) const {
+ Terminator terminator{sourceFile, line};
+ auto xCatKind{x.type().GetCategoryAndKind()};
+ auto yCatKind{y.type().GetCategoryAndKind()};
+ RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+ RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
+ RUNTIME_CHECK(terminator, yCatKind->first == YCAT);
+ if constexpr (constexpr auto resultType{
+ GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+ return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
+ resultType->second, CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
+ result, x, y, terminator);
+ }
+ terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
+ static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
+ }
+};
} // namespace
namespace Fortran::runtime {
@@ -399,6 +423,24 @@ void RTDEF(MatmulTransposeDirect)(const Descriptor &result, const Descriptor &x,
MatmulTranspose<false>{}(result, x, y, sourceFile, line);
}
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDEF(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+ const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+ int line) { \
+ MatmulTransposeHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
+ YKIND>{}(result, x, y, sourceFile, line); \
+ }
+
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDEF(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
+ Descriptor & result, const Descriptor &x, const Descriptor &y, \
+ const char *sourceFile, int line) { \
+ MatmulTransposeHelper<false, TypeCategory::XCAT, XKIND, \
+ TypeCategory::YCAT, YKIND>{}(result, x, y, sourceFile, line); \
+ }
+
+#include "flang/Runtime/matmul-instances.inc"
+
RT_EXT_API_GROUP_END
} // extern "C"
} // namespace Fortran::runtime
diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
index 8f9b50a549e1f..504d1aa4dc4a4 100644
--- a/flang/runtime/matmul.cpp
+++ b/flang/runtime/matmul.cpp
@@ -28,7 +28,8 @@
#include "flang/Runtime/descriptor.h"
#include <cstring>
-namespace Fortran::runtime {
+namespace {
+using namespace Fortran::runtime;
// Suppress the warnings about calling __host__-only std::complex operators,
// defined in C++ STD header files, from __device__ code.
@@ -455,7 +456,8 @@ template <bool IS_ALLOCATING> struct Matmul {
Terminator &terminator) const {
if constexpr (constexpr auto resultType{
GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
- if constexpr (common::IsNumericTypeCategory(resultType->first) ||
+ if constexpr (Fortran::common::IsNumericTypeCategory(
+ resultType->first) ||
resultType->first == TypeCategory::Logical) {
return DoMatmul<IS_ALLOCATING, resultType->first,
resultType->second, CppTypeFor<XCAT, XKIND>,
@@ -483,6 +485,32 @@ template <bool IS_ALLOCATING> struct Matmul {
}
};
+template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
+ int YKIND>
+struct MatmulHelper {
+ using ResultDescriptor =
+ std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+ RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
+ const Descriptor &y, const char *sourceFile, int line) const {
+ Terminator terminator{sourceFile, line};
+ auto xCatKind{x.type().GetCategoryAndKind()};
+ auto yCatKind{y.type().GetCategoryAndKind()};
+ RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+ RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
+ RUNTIME_CHECK(terminator, yCatKind->first == YCAT);
+ if constexpr (constexpr auto resultType{
+ GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+ return DoMatmul<IS_ALLOCATING, resultType->first, resultType->second,
+ CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
+ result, x, y, terminator);
+ }
+ terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
+ static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
+ }
+};
+} // namespace
+
+namespace Fortran::runtime {
extern "C" {
RT_EXT_API_GROUP_BEGIN
@@ -495,6 +523,24 @@ void RTDEF(MatmulDirect)(const Descriptor &result, const Descriptor &x,
Matmul<false>{}(result, x, y, sourceFile, line);
}
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDEF(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+ const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+ int line) { \
+ MatmulHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
+ YKIND>{}(result, x, y, sourceFile, line); \
+ }
+
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDEF(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+ const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+ int line) { \
+ MatmulHelper<false, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
+ YKIND>{}(result, x, y, sourceFile, line); \
+ }
+
+#include "flang/Runtime/matmul-instances.inc"
+
RT_EXT_API_GROUP_END
} // extern "C"
} // namespace Fortran::runtime
diff --git a/flang/unittests/Runtime/Matmul.cpp b/flang/unittests/Runtime/Matmul.cpp
index 1d6c5ccc609b4..226dbc5ae9eeb 100644
--- a/flang/unittests/Runtime/Matmul.cpp
+++ b/flang/unittests/Runtime/Matmul.cpp
@@ -63,6 +63,29 @@ TEST(Matmul, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
+ RTNAME(MatmulInteger4Integer2)(result, *x, *y, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+
+ std::memset(
+ result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
+ result.GetDimension(0).SetLowerBound(0);
+ result.GetDimension(1).SetLowerBound(2);
+ RTNAME(MatmulDirectInteger4Integer2)(result, *x, *y, __FILE__, __LINE__);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
RTNAME(Matmul)(result, *v, *x, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 1);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -73,6 +96,16 @@ TEST(Matmul, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
result.Destroy();
+ RTNAME(MatmulInteger8Integer4)(result, *v, *x, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
+ result.Destroy();
+
RTNAME(Matmul)(result, *y, *v, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 1);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -83,6 +116,16 @@ TEST(Matmul, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
result.Destroy();
+ RTNAME(MatmulInteger2Integer8)(result, *y, *v, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+ result.Destroy();
+
// Test non-contiguous sections.
static constexpr int sectionRank{2};
StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
@@ -129,6 +172,19 @@ TEST(Matmul, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
+ RTNAME(MatmulInteger4Integer2)(result, sectionX2, *y, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
RTNAME(Matmul)(result, *x, sectionY2, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -142,6 +198,19 @@ TEST(Matmul, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
+ RTNAME(MatmulInteger4Integer2)(result, *x, sectionY2, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
RTNAME(Matmul)(result, sectionX2, sectionY2, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -155,6 +224,20 @@ TEST(Matmul, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
+ RTNAME(MatmulInteger4Integer2)
+ (result, sectionX2, sectionY2, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
RTNAME(Matmul)(result, *v, sectionX2, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 1);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -165,6 +248,16 @@ TEST(Matmul, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
result.Destroy();
+ RTNAME(MatmulInteger8Integer4)(result, *v, sectionX2, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
+ result.Destroy();
+
RTNAME(Matmul)(result, sectionY2, *v, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 1);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -175,6 +268,16 @@ TEST(Matmul, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
result.Destroy();
+ RTNAME(MatmulInteger2Integer8)(result, sectionY2, *v, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+ result.Destroy();
+
// X F F T Y F T
// F T T F T
// F F
@@ -197,4 +300,22 @@ TEST(Matmul, Basic) {
static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
EXPECT_TRUE(
static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+ result.Destroy();
+
+ RTNAME(MatmulLogical1Logical2)(result, *xLog, *yLog, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
+ EXPECT_FALSE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
+ EXPECT_FALSE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+ EXPECT_FALSE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
+ EXPECT_TRUE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+ result.Destroy();
}
diff --git a/flang/unittests/Runtime/MatmulTranspose.cpp b/flang/unittests/Runtime/MatmulTranspose.cpp
index fe946f6d5a201..391c2e1b144ea 100644
--- a/flang/unittests/Runtime/MatmulTranspose.cpp
+++ b/flang/unittests/Runtime/MatmulTranspose.cpp
@@ -69,6 +69,30 @@ TEST(MatmulTranspose, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
+ RTNAME(MatmulTransposeInteger4Integer2)(result, *x, *y, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+
+ std::memset(
+ result.raw().base_addr, 0, result.Elements() * result.ElementBytes());
+ result.GetDimension(0).SetLowerBound(0);
+ result.GetDimension(1).SetLowerBound(2);
+ RTNAME(MatmulTransposeDirectInteger4Integer2)
+ (result, *x, *y, __FILE__, __LINE__);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
RTNAME(MatmulTranspose)(result, *z, *v, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 1);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -79,6 +103,16 @@ TEST(MatmulTranspose, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
result.Destroy();
+ RTNAME(MatmulTransposeInteger2Integer8)(result, *z, *v, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+ result.Destroy();
+
RTNAME(MatmulTranspose)(result, *m, *z, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
ASSERT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -100,6 +134,27 @@ TEST(MatmulTranspose, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(11), 19);
result.Destroy();
+ RTNAME(MatmulTransposeInteger2Integer2)(result, *m, *z, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ ASSERT_EQ(result.GetDimension(0).LowerBound(), 1);
+ ASSERT_EQ(result.GetDimension(0).UpperBound(), 4);
+ ASSERT_EQ(result.GetDimension(1).LowerBound(), 1);
+ ASSERT_EQ(result.GetDimension(1).UpperBound(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 2}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(0), 0);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(1), 9);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(2), 6);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(3), 15);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(4), 0);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(5), 10);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(6), 7);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(7), 17);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(8), 0);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(9), 11);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(10), 8);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(11), 19);
+ result.Destroy();
+
// Test non-contiguous sections.
static constexpr int sectionRank{2};
StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
@@ -162,6 +217,20 @@ TEST(MatmulTranspose, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
+ RTNAME(MatmulTransposeInteger4Integer2)
+ (result, sectionX2, *y, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
RTNAME(MatmulTranspose)(result, *x, sectionY2, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -175,6 +244,20 @@ TEST(MatmulTranspose, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
+ RTNAME(MatmulTransposeInteger4Integer2)
+ (result, *x, sectionY2, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
RTNAME(MatmulTranspose)(result, sectionX2, sectionY2, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 2);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -188,6 +271,20 @@ TEST(MatmulTranspose, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
result.Destroy();
+ RTNAME(MatmulTransposeInteger4Integer2)
+ (result, sectionX2, sectionY2, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
RTNAME(MatmulTranspose)(result, sectionZ2, *v, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 1);
EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
@@ -198,6 +295,17 @@ TEST(MatmulTranspose, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
result.Destroy();
+ RTNAME(MatmulTransposeInteger2Integer8)
+ (result, sectionZ2, *v, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+ result.Destroy();
+
// X F F    Y F T    V T F T
//   T F      F T
//   T T      F F
@@ -222,6 +330,25 @@ TEST(MatmulTranspose, Basic) {
static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
EXPECT_FALSE(
static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+ result.Destroy();
+
+ RTNAME(MatmulTransposeLogical1Logical2)
+ (result, *xLog, *yLog, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
+ EXPECT_FALSE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
+ EXPECT_FALSE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+ EXPECT_TRUE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2)));
+ EXPECT_FALSE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3)));
+ result.Destroy();
RTNAME(MatmulTranspose)(result, *yLog, *vLog, __FILE__, __LINE__);
ASSERT_EQ(result.rank(), 1);
@@ -232,4 +359,17 @@ TEST(MatmulTranspose, Basic) {
static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
EXPECT_TRUE(
static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+ result.Destroy();
+
+ RTNAME(MatmulTransposeLogical2Logical1)
+ (result, *yLog, *vLog, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2}));
+ EXPECT_FALSE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0)));
+ EXPECT_TRUE(
+ static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1)));
+ result.Destroy();
}
More information about the flang-commits
mailing list