[flang-commits] [flang] [flang][runtime] Split MATMUL[_TRANSPOSE] into separate entries. (PR #97406)
via flang-commits
flang-commits at lists.llvm.org
Tue Jul 2 05:07:22 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
Device compilation is much faster for separate MATMUL[_TRANSPOSE]
entries than for a single one that covers all data types.
The lowering changes and the removal of the generic entries will follow.
---
Patch is 39.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97406.diff
7 Files Affected:
- (added) flang/include/flang/Runtime/matmul-instances.inc (+261)
- (modified) flang/include/flang/Runtime/matmul-transpose.h (+17)
- (modified) flang/include/flang/Runtime/matmul.h (+17)
- (modified) flang/runtime/matmul-transpose.cpp (+42)
- (modified) flang/runtime/matmul.cpp (+48-2)
- (modified) flang/unittests/Runtime/Matmul.cpp (+121)
- (modified) flang/unittests/Runtime/MatmulTranspose.cpp (+140)
``````````diff
diff --git a/flang/include/flang/Runtime/matmul-instances.inc b/flang/include/flang/Runtime/matmul-instances.inc
new file mode 100644
index 0000000000000..970b03339cd5e
--- /dev/null
+++ b/flang/include/flang/Runtime/matmul-instances.inc
@@ -0,0 +1,261 @@
+//===-- include/flang/Runtime/matmul-instances.inc --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Helper macros to instantiate MATMUL/MATMUL_TRANSPOSE definitions
+// for different data types of the input arguments.
+//===----------------------------------------------------------------------===//
+
+// Includer contract: define both MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND)
+// and MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) before including this
+// file. Each macro is expanded once for every (X category, X kind,
+// Y category, Y kind) operand pair listed below. Both macros are #undef'd
+// at the end of this file, so every inclusion needs fresh definitions.
+#ifndef MATMUL_INSTANCE
+#error "Define MATMUL_INSTANCE before including this file"
+#endif
+
+#ifndef MATMUL_DIRECT_INSTANCE
+#error "Define MATMUL_DIRECT_INSTANCE before including this file"
+#endif
+
+// clang-format off
+
+// Operand pairs instantiated unconditionally: every combination of
+// INTEGER(1/2/4/8), REAL(4/8) and COMPLEX(4/8). Pairs involving INTEGER(16),
+// REAL/COMPLEX(10) or REAL/COMPLEX(16) are guarded by target checks further
+// down; LOGICAL pairs are listed separately at the end.
+// NOTE: the blank line after the last entry of each table terminates the
+// backslash-continued #define; it must stay blank.
+#define FOREACH_MATMUL_TYPE_PAIR(macro) \
+ macro(Integer, 1, Integer, 1) \
+ macro(Integer, 1, Integer, 2) \
+ macro(Integer, 1, Integer, 4) \
+ macro(Integer, 1, Integer, 8) \
+ macro(Integer, 2, Integer, 1) \
+ macro(Integer, 2, Integer, 2) \
+ macro(Integer, 2, Integer, 4) \
+ macro(Integer, 2, Integer, 8) \
+ macro(Integer, 4, Integer, 1) \
+ macro(Integer, 4, Integer, 2) \
+ macro(Integer, 4, Integer, 4) \
+ macro(Integer, 4, Integer, 8) \
+ macro(Integer, 8, Integer, 1) \
+ macro(Integer, 8, Integer, 2) \
+ macro(Integer, 8, Integer, 4) \
+ macro(Integer, 8, Integer, 8) \
+ macro(Integer, 1, Real, 4) \
+ macro(Integer, 1, Real, 8) \
+ macro(Integer, 2, Real, 4) \
+ macro(Integer, 2, Real, 8) \
+ macro(Integer, 4, Real, 4) \
+ macro(Integer, 4, Real, 8) \
+ macro(Integer, 8, Real, 4) \
+ macro(Integer, 8, Real, 8) \
+ macro(Integer, 1, Complex, 4) \
+ macro(Integer, 1, Complex, 8) \
+ macro(Integer, 2, Complex, 4) \
+ macro(Integer, 2, Complex, 8) \
+ macro(Integer, 4, Complex, 4) \
+ macro(Integer, 4, Complex, 8) \
+ macro(Integer, 8, Complex, 4) \
+ macro(Integer, 8, Complex, 8) \
+ macro(Real, 4, Integer, 1) \
+ macro(Real, 4, Integer, 2) \
+ macro(Real, 4, Integer, 4) \
+ macro(Real, 4, Integer, 8) \
+ macro(Real, 8, Integer, 1) \
+ macro(Real, 8, Integer, 2) \
+ macro(Real, 8, Integer, 4) \
+ macro(Real, 8, Integer, 8) \
+ macro(Real, 4, Real, 4) \
+ macro(Real, 4, Real, 8) \
+ macro(Real, 8, Real, 4) \
+ macro(Real, 8, Real, 8) \
+ macro(Real, 4, Complex, 4) \
+ macro(Real, 4, Complex, 8) \
+ macro(Real, 8, Complex, 4) \
+ macro(Real, 8, Complex, 8) \
+ macro(Complex, 4, Integer, 1) \
+ macro(Complex, 4, Integer, 2) \
+ macro(Complex, 4, Integer, 4) \
+ macro(Complex, 4, Integer, 8) \
+ macro(Complex, 8, Integer, 1) \
+ macro(Complex, 8, Integer, 2) \
+ macro(Complex, 8, Integer, 4) \
+ macro(Complex, 8, Integer, 8) \
+ macro(Complex, 4, Real, 4) \
+ macro(Complex, 4, Real, 8) \
+ macro(Complex, 8, Real, 4) \
+ macro(Complex, 8, Real, 8) \
+ macro(Complex, 4, Complex, 4) \
+ macro(Complex, 4, Complex, 8) \
+ macro(Complex, 8, Complex, 4) \
+ macro(Complex, 8, Complex, 8) \
+
+FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
+
+#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+// Pairs involving INTEGER(16); instantiated only when __SIZEOF_INT128__ is
+// defined and AVOID_NATIVE_UINT128_T is not set. As in the other tables of
+// this file, every mixed pair is listed in both operand orders, so
+// INTEGER(1/2/4/8) x INTEGER(16) is present alongside
+// INTEGER(16) x INTEGER(1/2/4/8).
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro) \
+ macro(Integer, 1, Integer, 16) \
+ macro(Integer, 2, Integer, 16) \
+ macro(Integer, 4, Integer, 16) \
+ macro(Integer, 8, Integer, 16) \
+ macro(Integer, 16, Integer, 1) \
+ macro(Integer, 16, Integer, 2) \
+ macro(Integer, 16, Integer, 4) \
+ macro(Integer, 16, Integer, 8) \
+ macro(Integer, 16, Integer, 16) \
+ macro(Integer, 16, Real, 4) \
+ macro(Integer, 16, Real, 8) \
+ macro(Integer, 16, Complex, 4) \
+ macro(Integer, 16, Complex, 8) \
+ macro(Real, 4, Integer, 16) \
+ macro(Real, 8, Integer, 16) \
+ macro(Complex, 4, Integer, 16) \
+ macro(Complex, 8, Integer, 16) \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)
+
+#if LDBL_MANT_DIG == 64
+// INTEGER(16) combined with REAL(10)/COMPLEX(10), in both operand orders.
+MATMUL_INSTANCE(Integer, 16, Real, 10)
+MATMUL_INSTANCE(Integer, 16, Complex, 10)
+MATMUL_INSTANCE(Real, 10, Integer, 16)
+MATMUL_INSTANCE(Complex, 10, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+// INTEGER(16) combined with REAL(16)/COMPLEX(16), in both operand orders.
+MATMUL_INSTANCE(Integer, 16, Real, 16)
+MATMUL_INSTANCE(Integer, 16, Complex, 16)
+MATMUL_INSTANCE(Real, 16, Integer, 16)
+MATMUL_INSTANCE(Complex, 16, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
+#endif
+#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+
+#if LDBL_MANT_DIG == 64
+// Pairs involving REAL(10)/COMPLEX(10); instantiated only when long double
+// has a 64-bit significand (LDBL_MANT_DIG == 64). INTEGER(16) x REAL(10)
+// pairs live in the INTEGER(16) section, since they also need a native
+// 128-bit integer.
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro) \
+ macro(Integer, 1, Real, 10) \
+ macro(Integer, 1, Complex, 10) \
+ macro(Integer, 2, Real, 10) \
+ macro(Integer, 2, Complex, 10) \
+ macro(Integer, 4, Real, 10) \
+ macro(Integer, 4, Complex, 10) \
+ macro(Integer, 8, Real, 10) \
+ macro(Integer, 8, Complex, 10) \
+ macro(Real, 4, Real, 10) \
+ macro(Real, 4, Complex, 10) \
+ macro(Real, 8, Real, 10) \
+ macro(Real, 8, Complex, 10) \
+ macro(Real, 10, Integer, 1) \
+ macro(Real, 10, Integer, 2) \
+ macro(Real, 10, Integer, 4) \
+ macro(Real, 10, Integer, 8) \
+ macro(Real, 10, Real, 4) \
+ macro(Real, 10, Real, 8) \
+ macro(Real, 10, Real, 10) \
+ macro(Real, 10, Complex, 4) \
+ macro(Real, 10, Complex, 8) \
+ macro(Real, 10, Complex, 10) \
+ macro(Complex, 4, Real, 10) \
+ macro(Complex, 4, Complex, 10) \
+ macro(Complex, 8, Real, 10) \
+ macro(Complex, 8, Complex, 10) \
+ macro(Complex, 10, Integer, 1) \
+ macro(Complex, 10, Integer, 2) \
+ macro(Complex, 10, Integer, 4) \
+ macro(Complex, 10, Integer, 8) \
+ macro(Complex, 10, Real, 4) \
+ macro(Complex, 10, Real, 8) \
+ macro(Complex, 10, Real, 10) \
+ macro(Complex, 10, Complex, 4) \
+ macro(Complex, 10, Complex, 8) \
+ macro(Complex, 10, Complex, 10) \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)
+
+#if HAS_FLOAT128
+// Mixed REAL/COMPLEX(10) x REAL/COMPLEX(16) pairs, in both operand orders;
+// these additionally require HAS_FLOAT128.
+MATMUL_INSTANCE(Real, 10, Real, 16)
+MATMUL_INSTANCE(Real, 10, Complex, 16)
+MATMUL_INSTANCE(Real, 16, Real, 10)
+MATMUL_INSTANCE(Real, 16, Complex, 10)
+MATMUL_INSTANCE(Complex, 10, Real, 16)
+MATMUL_INSTANCE(Complex, 10, Complex, 16)
+MATMUL_INSTANCE(Complex, 16, Real, 10)
+MATMUL_INSTANCE(Complex, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Real, 10, Real, 16)
+MATMUL_DIRECT_INSTANCE(Real, 10, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Real, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Real, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Real, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
+#endif
+#endif // LDBL_MANT_DIG == 64
+
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+// Pairs involving REAL(16)/COMPLEX(16); instantiated when long double has a
+// 113-bit significand or HAS_FLOAT128 is set. INTEGER(16) x REAL(16) pairs
+// live in the INTEGER(16) section and REAL(10) x REAL(16) pairs in the
+// REAL(10) section, because they depend on additional conditions.
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro) \
+ macro(Integer, 1, Real, 16) \
+ macro(Integer, 1, Complex, 16) \
+ macro(Integer, 2, Real, 16) \
+ macro(Integer, 2, Complex, 16) \
+ macro(Integer, 4, Real, 16) \
+ macro(Integer, 4, Complex, 16) \
+ macro(Integer, 8, Real, 16) \
+ macro(Integer, 8, Complex, 16) \
+ macro(Real, 4, Real, 16) \
+ macro(Real, 4, Complex, 16) \
+ macro(Real, 8, Real, 16) \
+ macro(Real, 8, Complex, 16) \
+ macro(Real, 16, Integer, 1) \
+ macro(Real, 16, Integer, 2) \
+ macro(Real, 16, Integer, 4) \
+ macro(Real, 16, Integer, 8) \
+ macro(Real, 16, Real, 4) \
+ macro(Real, 16, Real, 8) \
+ macro(Real, 16, Real, 16) \
+ macro(Real, 16, Complex, 4) \
+ macro(Real, 16, Complex, 8) \
+ macro(Real, 16, Complex, 16) \
+ macro(Complex, 4, Real, 16) \
+ macro(Complex, 4, Complex, 16) \
+ macro(Complex, 8, Real, 16) \
+ macro(Complex, 8, Complex, 16) \
+ macro(Complex, 16, Integer, 1) \
+ macro(Complex, 16, Integer, 2) \
+ macro(Complex, 16, Integer, 4) \
+ macro(Complex, 16, Integer, 8) \
+ macro(Complex, 16, Real, 4) \
+ macro(Complex, 16, Real, 8) \
+ macro(Complex, 16, Real, 16) \
+ macro(Complex, 16, Complex, 4) \
+ macro(Complex, 16, Complex, 8) \
+ macro(Complex, 16, Complex, 16) \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
+#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128
+
+// LOGICAL x LOGICAL pairs for kinds 1, 2, 4 and 8.
+#define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
+ macro(Logical, 1, Logical, 1) \
+ macro(Logical, 1, Logical, 2) \
+ macro(Logical, 1, Logical, 4) \
+ macro(Logical, 1, Logical, 8) \
+ macro(Logical, 2, Logical, 1) \
+ macro(Logical, 2, Logical, 2) \
+ macro(Logical, 2, Logical, 4) \
+ macro(Logical, 2, Logical, 8) \
+ macro(Logical, 4, Logical, 1) \
+ macro(Logical, 4, Logical, 2) \
+ macro(Logical, 4, Logical, 4) \
+ macro(Logical, 4, Logical, 8) \
+ macro(Logical, 8, Logical, 1) \
+ macro(Logical, 8, Logical, 2) \
+ macro(Logical, 8, Logical, 4) \
+ macro(Logical, 8, Logical, 8) \
+
+FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_INSTANCE)
+FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
+
+// This file is included from the public runtime headers matmul.h and
+// matmul-transpose.h as well as from the runtime sources. Drop every macro
+// it defines or consumes so none of them leak into includers. #undef of a
+// table that was not defined on this target is harmless.
+#undef FOREACH_MATMUL_TYPE_PAIR
+#undef FOREACH_MATMUL_TYPE_PAIR_WITH_INT16
+#undef FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10
+#undef FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16
+#undef FOREACH_MATMUL_LOGICAL_TYPE_PAIR
+#undef MATMUL_INSTANCE
+#undef MATMUL_DIRECT_INSTANCE
+
+// clang-format on
diff --git a/flang/include/flang/Runtime/matmul-transpose.h b/flang/include/flang/Runtime/matmul-transpose.h
index 5eb5896972e0f..d0a5005a1a8bd 100644
--- a/flang/include/flang/Runtime/matmul-transpose.h
+++ b/flang/include/flang/Runtime/matmul-transpose.h
@@ -10,6 +10,8 @@
#ifndef FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
#define FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
+#include "flang/Common/float128.h"
+#include "flang/Common/uint128.h"
#include "flang/Runtime/entry-names.h"
namespace Fortran::runtime {
class Descriptor;
@@ -25,6 +27,21 @@ void RTDECL(MatmulTranspose)(Descriptor &, const Descriptor &,
// and have a valid base address.
void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
const Descriptor &, const char *sourceFile = nullptr, int line = 0);
+
+// MATMUL(TRANSPOSE()) entry points specialized by both the category and the
+// kind of each operand. The entry-point name encodes both; for example,
+// MatmulTransposeInteger4Real8 takes an INTEGER(4) x and a REAL(8) y, and
+// MatmulTransposeDirectInteger4Real8 is the matching variant that follows
+// the MatmulTransposeDirect contract above. The shape information is taken
+// from the arguments' descriptors. The operand types are expected to match
+// the entry-point name.
+// One declaration per entry point is generated for each operand pair
+// listed in matmul-instances.inc, which #undefs both macros afterwards.
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDECL(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+ const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+ int line);
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDECL(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
+ Descriptor & result, const Descriptor &x, const Descriptor &y, \
+ const char *sourceFile, int line);
+
+#include "matmul-instances.inc"
+
} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
diff --git a/flang/include/flang/Runtime/matmul.h b/flang/include/flang/Runtime/matmul.h
index 40581d44de9e2..1a5e39eb8813f 100644
--- a/flang/include/flang/Runtime/matmul.h
+++ b/flang/include/flang/Runtime/matmul.h
@@ -10,6 +10,8 @@
#ifndef FORTRAN_RUNTIME_MATMUL_H_
#define FORTRAN_RUNTIME_MATMUL_H_
+#include "flang/Common/float128.h"
+#include "flang/Common/uint128.h"
#include "flang/Runtime/entry-names.h"
namespace Fortran::runtime {
class Descriptor;
@@ -24,6 +26,21 @@ void RTDECL(Matmul)(Descriptor &, const Descriptor &, const Descriptor &,
// and have a valid base address.
void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
const Descriptor &, const char *sourceFile = nullptr, int line = 0);
+
+// MATMUL entry points specialized by both the category and the kind of each
+// operand. The entry-point name encodes both; for example,
+// MatmulInteger4Real8 takes an INTEGER(4) x and a REAL(8) y, and
+// MatmulDirectInteger4Real8 is the matching variant that follows the
+// MatmulDirect contract above. The shape information is taken from the
+// arguments' descriptors. The operand types are expected to match the
+// entry-point name.
+// One declaration per entry point is generated for each operand pair
+// listed in matmul-instances.inc, which #undefs both macros afterwards.
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDECL(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+ const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+ int line);
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDECL(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+ const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+ int line);
+
+#include "matmul-instances.inc"
+
} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_MATMUL_H_
diff --git a/flang/runtime/matmul-transpose.cpp b/flang/runtime/matmul-transpose.cpp
index a12d188266f7c..1c998fa8cf6c1 100644
--- a/flang/runtime/matmul-transpose.cpp
+++ b/flang/runtime/matmul-transpose.cpp
@@ -384,6 +384,30 @@ template <bool IS_ALLOCATING> struct MatmulTranspose {
x, y, terminator, yCatKind->first, yCatKind->second);
}
};
+
+// Implements one MatmulTranspose<XCAT><XKIND><YCAT><YKIND> entry point
+// (IS_ALLOCATING=true) or its MatmulTransposeDirect counterpart
+// (IS_ALLOCATING=false). The operand element types are fixed by the
+// template arguments, so each entry point instantiates DoMatmulTranspose
+// for a single operand type pair.
+template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
+ int YKIND>
+struct MatmulTransposeHelper {
+ using ResultDescriptor =
+ std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+ RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
+ const Descriptor &y, const char *sourceFile, int line) const {
+ Terminator terminator{sourceFile, line};
+ auto xCatKind{x.type().GetCategoryAndKind()};
+ auto yCatKind{y.type().GetCategoryAndKind()};
+ RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+ // The operand data is reinterpreted as CppTypeFor<XCAT, XKIND> and
+ // CppTypeFor<YCAT, YKIND> below. Both the category and the kind recorded
+ // in each descriptor must therefore match this entry point; otherwise a
+ // kind mismatch would silently read elements of the wrong size.
+ RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
+ RUNTIME_CHECK(terminator, xCatKind->second == XKIND);
+ RUNTIME_CHECK(terminator, yCatKind->first == YCAT);
+ RUNTIME_CHECK(terminator, yCatKind->second == YKIND);
+ if constexpr (constexpr auto resultType{
+ GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+ return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
+ resultType->second, CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
+ result, x, y, terminator);
+ }
+ // Reached only when GetResultType yields no result type for this pair.
+ terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
+ static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
+ }
+};
} // namespace
namespace Fortran::runtime {
@@ -399,6 +423,24 @@ void RTDEF(MatmulTransposeDirect)(const Descriptor &result, const Descriptor &x,
MatmulTranspose<false>{}(result, x, y, sourceFile, line);
}
+// Definitions of the specialized MATMUL(TRANSPOSE()) entry points, one pair
+// per operand type pair listed in matmul-instances.inc. Each entry point
+// forwards to MatmulTransposeHelper with its operand types fixed at compile
+// time. The allocating form uses IS_ALLOCATING=true; the Direct form uses
+// IS_ALLOCATING=false.
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDEF(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+ const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+ int line) { \
+ using Impl = MatmulTransposeHelper</*IS_ALLOCATING=*/true, \
+ TypeCategory::XCAT, XKIND, TypeCategory::YCAT, YKIND>; \
+ Impl{}(result, x, y, sourceFile, line); \
+ }
+
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+ void RTDEF(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
+ Descriptor & result, const Descriptor &x, const Descriptor &y, \
+ const char *sourceFile, int line) { \
+ using Impl = MatmulTransposeHelper</*IS_ALLOCATING=*/false, \
+ TypeCategory::XCAT, XKIND, TypeCategory::YCAT, YKIND>; \
+ Impl{}(result, x, y, sourceFile, line); \
+ }
+
+#include "flang/Runtime/matmul-instances.inc"
+
RT_EXT_API_GROUP_END
} // extern "C"
} // namespace Fortran::runtime
diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
index 8f9b50a549e1f..504d1aa4dc4a4 100644
--- a/flang/runtime/matmul.cpp
+++ b/flang/runtime/matmul.cpp
@@ -28,7 +28,8 @@
#include "flang/Runtime/descriptor.h"
#include <cstring>
-namespace Fortran::runtime {
+namespace {
+using namespace Fortran::runtime;
// Suppress the warnings about calling __host__-only std::complex operators,
// defined in C++ STD header files, from __device__ code.
@@ -455,7 +456,8 @@ template <bool IS_ALLOCATING> struct Matmul {
Terminator &terminator) const {
if constexpr (constexpr auto resultType{
GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
- if constexpr (common::IsNumericTypeCategory(resultType->first) ||
+ if constexpr (Fortran::common::IsNumericTypeCategory(
+ resultType->first) ||
resultType->first == TypeCategory::Logical) {
return DoMatmul<IS_ALLOCATING, resultType->first,
resultType->second, CppTypeFor<XCAT, XKIND>,
@@ -483,6 +485,32 @@ template <bool IS_ALLOCATING> struct Matmul {
}
};
+template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
+ int YKIND>
+struct MatmulHelper {
+ using ResultDescriptor =
+ std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+ RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
+ const Descriptor &y, const char *sourceFile, int line) const {
+ Terminator terminator{sourceFile, line};
+ auto xCatKind{x.type().GetCategoryAndKind()};
+ auto yCatKind{y...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/97406
More information about the flang-commits
mailing list