[flang-commits] [flang] 4d97717 - [flang] Improved performance of runtime Matmul/MatmulTranspose.
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Tue Aug 29 17:04:22 PDT 2023
Author: Slava Zakharin
Date: 2023-08-29T17:04:00-07:00
New Revision: 4d9771741d40cc9cfcccb6b033f43689d36b705a
URL: https://github.com/llvm/llvm-project/commit/4d9771741d40cc9cfcccb6b033f43689d36b705a
DIFF: https://github.com/llvm/llvm-project/commit/4d9771741d40cc9cfcccb6b033f43689d36b705a.diff
LOG: [flang] Improved performance of runtime Matmul/MatmulTranspose.
This patch mostly affects the performance of code produced by
HLFIR lowering. If a MATMUL argument is an array slice, then
HLFIR lowering passes the slice to the runtime, whereas
FIR lowering would create a contiguous temporary for the slice.
The runtime now uses its specialized kernels for slices whose
leading dimension is contiguous (columns may be separated by a
stride), which can perform better than the generic implementation.
This patch improves CPU2000/178.galgel, making the HLFIR version
faster than the FIR version (because it avoids the temporary copies
of MATMUL arguments).
Reviewed By: klausler
Differential Revision: https://reviews.llvm.org/D159134
Added:
Modified:
flang/runtime/matmul-transpose.cpp
flang/runtime/matmul.cpp
flang/unittests/Runtime/Matmul.cpp
flang/unittests/Runtime/MatmulTranspose.cpp
Removed:
################################################################################
diff --git a/flang/runtime/matmul-transpose.cpp b/flang/runtime/matmul-transpose.cpp
index 1a31ccc4591cd1..43fcf7c0849069 100644
--- a/flang/runtime/matmul-transpose.cpp
+++ b/flang/runtime/matmul-transpose.cpp
@@ -52,25 +52,64 @@ using namespace Fortran::runtime;
// DO 2 I = 1, NROWS
// DO 2 K = 1, N
// 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) ! loop-invariant last term
-template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
+ bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
inline static void MatrixTransposedTimesMatrix(
CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
- SubscriptValue n) {
+ SubscriptValue n, std::size_t xColumnByteStride = 0,
+ std::size_t yColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * cols * sizeof *product);
for (SubscriptValue j{0}; j < cols; ++j) {
for (SubscriptValue i{0}; i < rows; ++i) {
for (SubscriptValue k{0}; k < n; ++k) {
- ResultType x_ki = static_cast<ResultType>(x[i * n + k]);
- ResultType y_kj = static_cast<ResultType>(y[j * n + k]);
+ ResultType x_ki;
+ if constexpr (!X_HAS_STRIDED_COLUMNS) {
+ x_ki = static_cast<ResultType>(x[i * n + k]);
+ } else {
+ x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>(
+ reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]);
+ }
+ ResultType y_kj;
+ if constexpr (!Y_HAS_STRIDED_COLUMNS) {
+ y_kj = static_cast<ResultType>(y[j * n + k]);
+ } else {
+ y_kj = static_cast<ResultType>(reinterpret_cast<const YT *>(
+ reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
+ }
product[j * rows + i] += x_ki * y_kj;
}
}
}
}
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+inline static void MatrixTransposedTimesMatrixHelper(
+ CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
+ SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
+ SubscriptValue n, std::optional<std::size_t> xColumnByteStride,
+ std::optional<std::size_t> yColumnByteStride) {
+ if (!xColumnByteStride) {
+ if (!yColumnByteStride) {
+ MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, false>(
+ product, rows, cols, x, y, n);
+ } else {
+ MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, false, true>(
+ product, rows, cols, x, y, n, 0, *yColumnByteStride);
+ }
+ } else {
+ if (!yColumnByteStride) {
+ MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, false>(
+ product, rows, cols, x, y, n, *xColumnByteStride);
+ } else {
+ MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT, true, true>(
+ product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride);
+ }
+ }
+}
+
// Contiguous numeric matrix*vector multiplication
// matrix(rows,n) * column vector(n) -> column vector(rows)
// Straightforward algorithm:
@@ -85,21 +124,43 @@ inline static void MatrixTransposedTimesMatrix(
// DO 2 I = 1, NROWS
// DO 2 K = 1, N
// 2 RES(I) = RES(I) + X(K,I)*Y(K)
-template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
+ bool X_HAS_STRIDED_COLUMNS>
inline static void MatrixTransposedTimesVector(
CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
- SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y) {
+ SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
+ std::size_t xColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * sizeof *product);
for (SubscriptValue i{0}; i < rows; ++i) {
for (SubscriptValue k{0}; k < n; ++k) {
- ResultType x_ki = static_cast<ResultType>(x[i * n + k]);
+ ResultType x_ki;
+ if constexpr (!X_HAS_STRIDED_COLUMNS) {
+ x_ki = static_cast<ResultType>(x[i * n + k]);
+ } else {
+ x_ki = static_cast<ResultType>(reinterpret_cast<const XT *>(
+ reinterpret_cast<const char *>(x) + i * xColumnByteStride)[k]);
+ }
ResultType y_k = static_cast<ResultType>(y[k]);
product[i] += x_ki * y_k;
}
}
}
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+inline static void MatrixTransposedTimesVectorHelper(
+ CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
+ SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
+ std::optional<std::size_t> xColumnByteStride) {
+ if (!xColumnByteStride) {
+ MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, false>(
+ product, rows, n, x, y);
+ } else {
+ MatrixTransposedTimesVector<RCAT, RKIND, XT, YT, true>(
+ product, rows, n, x, y, *xColumnByteStride);
+ }
+}
+
// Implements an instance of MATMUL for given argument types.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
typename YT>
@@ -149,19 +210,39 @@ inline static void DoMatmulTranspose(
const SubscriptValue rows{extent[0]};
const SubscriptValue cols{extent[1]};
if constexpr (RCAT != TypeCategory::Logical) {
- if (x.IsContiguous() && y.IsContiguous() &&
+ if (x.IsContiguous(1) && y.IsContiguous(1) &&
(IS_ALLOCATING || result.IsContiguous())) {
- // Contiguous numeric matrices
+ // Contiguous numeric matrices (maybe with columns
+ // separated by a stride).
+ std::optional<std::size_t> xColumnByteStride;
+ if (!x.IsContiguous()) {
+ // X's columns are strided.
+ SubscriptValue xAt[2]{};
+ x.GetLowerBounds(xAt);
+ xAt[1]++;
+ xColumnByteStride = x.SubscriptsToByteOffset(xAt);
+ }
+ std::optional<std::size_t> yColumnByteStride;
+ if (!y.IsContiguous()) {
+ // Y's columns are strided.
+ SubscriptValue yAt[2]{};
+ y.GetLowerBounds(yAt);
+ yAt[1]++;
+ yColumnByteStride = y.SubscriptsToByteOffset(yAt);
+ }
if (resRank == 2) { // M*M -> M
- MatrixTransposedTimesMatrix<RCAT, RKIND, XT, YT>(
+ // TODO: use BLAS-3 GEMM for supported types.
+ MatrixTransposedTimesMatrixHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), rows, cols,
- x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
+ x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
+ yColumnByteStride);
return;
}
if (xRank == 2) { // M*V -> V
- MatrixTransposedTimesVector<RCAT, RKIND, XT, YT>(
+ // TODO: use BLAS-2 GEMV for supported types.
+ MatrixTransposedTimesVectorHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), rows, n,
- x.OffsetElement<XT>(), y.OffsetElement<YT>());
+ x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
return;
}
// else V*M -> V (not allowed because TRANSPOSE() is only defined for rank
diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
index df260e1fa5ebd1..b46a94de01ceda 100644
--- a/flang/runtime/matmul.cpp
+++ b/flang/runtime/matmul.cpp
@@ -69,10 +69,12 @@ class Accumulator {
// DO 2 J = 1, NCOLS
// DO 2 I = 1, NROWS
// 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
-template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
+ bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x,
- const YT *RESTRICT y, SubscriptValue n) {
+ const YT *RESTRICT y, SubscriptValue n, std::size_t xColumnByteStride = 0,
+ std::size_t yColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * cols * sizeof *product);
const XT *RESTRICT xp0{x};
@@ -80,12 +82,48 @@ inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
ResultType *RESTRICT p{product};
for (SubscriptValue j{0}; j < cols; ++j) {
const XT *RESTRICT xp{xp0};
- auto yv{static_cast<ResultType>(y[k + j * n])};
+ ResultType yv;
+ if constexpr (!Y_HAS_STRIDED_COLUMNS) {
+ yv = static_cast<ResultType>(y[k + j * n]);
+ } else {
+ yv = static_cast<ResultType>(reinterpret_cast<const YT *>(
+ reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
+ }
for (SubscriptValue i{0}; i < rows; ++i) {
*p++ += static_cast<ResultType>(*xp++) * yv;
}
}
- xp0 += rows;
+ if constexpr (!X_HAS_STRIDED_COLUMNS) {
+ xp0 += rows;
+ } else {
+ xp0 = reinterpret_cast<const XT *>(
+ reinterpret_cast<const char *>(xp0) + xColumnByteStride);
+ }
+ }
+}
+
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+inline void MatrixTimesMatrixHelper(CppTypeFor<RCAT, RKIND> *RESTRICT product,
+ SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x,
+ const YT *RESTRICT y, SubscriptValue n,
+ std::optional<std::size_t> xColumnByteStride,
+ std::optional<std::size_t> yColumnByteStride) {
+ if (!xColumnByteStride) {
+ if (!yColumnByteStride) {
+ MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>(
+ product, rows, cols, x, y, n);
+ } else {
+ MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>(
+ product, rows, cols, x, y, n, 0, *yColumnByteStride);
+ }
+ } else {
+ if (!yColumnByteStride) {
+ MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>(
+ product, rows, cols, x, y, n, *xColumnByteStride);
+ } else {
+ MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>(
+ product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride);
+ }
}
}
@@ -103,18 +141,37 @@ inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
// DO 2 K = 1, N
// DO 2 J = 1, NROWS
// 2 RES(J) = RES(J) + X(J,K)*Y(K)
-template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
+ bool X_HAS_STRIDED_COLUMNS>
inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x,
- const YT *RESTRICT y) {
+ const YT *RESTRICT y, std::size_t xColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, rows * sizeof *product);
+ [[maybe_unused]] const XT *RESTRICT xp0{x};
for (SubscriptValue k{0}; k < n; ++k) {
ResultType *RESTRICT p{product};
auto yv{static_cast<ResultType>(*y++)};
for (SubscriptValue j{0}; j < rows; ++j) {
*p++ += static_cast<ResultType>(*x++) * yv;
}
+ if constexpr (X_HAS_STRIDED_COLUMNS) {
+ xp0 = reinterpret_cast<const XT *>(
+ reinterpret_cast<const char *>(xp0) + xColumnByteStride);
+ x = xp0;
+ }
+ }
+}
+
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+inline void MatrixTimesVectorHelper(CppTypeFor<RCAT, RKIND> *RESTRICT product,
+ SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x,
+ const YT *RESTRICT y, std::optional<std::size_t> xColumnByteStride) {
+ if (!xColumnByteStride) {
+ MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y);
+ } else {
+ MatrixTimesVector<RCAT, RKIND, XT, YT, true>(
+ product, rows, n, x, y, *xColumnByteStride);
}
}
@@ -132,10 +189,11 @@ inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT product,
// DO 2 K = 1, N
// DO 2 J = 1, NCOLS
// 2 RES(J) = RES(J) + X(K)*Y(K,J)
-template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
+ bool Y_HAS_STRIDED_COLUMNS>
inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x,
- const YT *RESTRICT y) {
+ const YT *RESTRICT y, std::size_t yColumnByteStride = 0) {
using ResultType = CppTypeFor<RCAT, RKIND>;
std::memset(product, 0, cols * sizeof *product);
for (SubscriptValue k{0}; k < n; ++k) {
@@ -144,11 +202,29 @@ inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT product,
const YT *RESTRICT yp{&y[k]};
for (SubscriptValue j{0}; j < cols; ++j) {
*p++ += xv * static_cast<ResultType>(*yp);
- yp += n;
+ if constexpr (!Y_HAS_STRIDED_COLUMNS) {
+ yp += n;
+ } else {
+ yp = reinterpret_cast<const YT *>(
+ reinterpret_cast<const char *>(yp) + yColumnByteStride);
+ }
}
}
}
+template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
+ bool SPARSE_COLUMNS = false>
+inline void VectorTimesMatrixHelper(CppTypeFor<RCAT, RKIND> *RESTRICT product,
+ SubscriptValue n, SubscriptValue cols, const XT *RESTRICT x,
+ const YT *RESTRICT y, std::optional<std::size_t> yColumnByteStride) {
+ if (!yColumnByteStride) {
+ VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y);
+ } else {
+ VectorTimesMatrix<RCAT, RKIND, XT, YT, true>(
+ product, n, cols, x, y, *yColumnByteStride);
+ }
+}
+
// Implements an instance of MATMUL for given argument types.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
typename YT>
@@ -194,13 +270,35 @@ static inline void DoMatmul(
CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
RKIND>;
if constexpr (RCAT != TypeCategory::Logical) {
- if (x.IsContiguous() && y.IsContiguous() &&
+ if (x.IsContiguous(1) && y.IsContiguous(1) &&
(IS_ALLOCATING || result.IsContiguous())) {
- // Contiguous numeric matrices
+ // Contiguous numeric matrices (maybe with columns
+ // separated by a stride).
+ std::optional<std::size_t> xColumnByteStride;
+ if (!x.IsContiguous()) {
+ // X's columns are strided.
+ SubscriptValue xAt[2]{};
+ x.GetLowerBounds(xAt);
+ xAt[1]++;
+ xColumnByteStride = x.SubscriptsToByteOffset(xAt);
+ }
+ std::optional<std::size_t> yColumnByteStride;
+ if (!y.IsContiguous()) {
+ // Y's columns are strided.
+ SubscriptValue yAt[2]{};
+ y.GetLowerBounds(yAt);
+ yAt[1]++;
+ yColumnByteStride = y.SubscriptsToByteOffset(yAt);
+ }
+ // Note that BLAS GEMM can also handle strided columns
+ // if the column stride (in elements) is passed as the
+ // leading dimension. That requires the byte stride to be
+ // divisible by the element size, which is usually true.
if (resRank == 2) { // M*M -> M
if (std::is_same_v<XT, YT>) {
if constexpr (std::is_same_v<XT, float>) {
// TODO: call BLAS-3 SGEMM
+ // TODO: try using CUTLASS for device.
} else if constexpr (std::is_same_v<XT, double>) {
// TODO: call BLAS-3 DGEMM
} else if constexpr (std::is_same_v<XT, std::complex<float>>) {
@@ -209,9 +307,10 @@ static inline void DoMatmul(
// TODO: call BLAS-3 ZGEMM
}
}
- MatrixTimesMatrix<RCAT, RKIND, XT, YT>(
+ MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), extent[0], extent[1],
- x.OffsetElement<XT>(), y.OffsetElement<YT>(), n);
+ x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
+ yColumnByteStride);
return;
} else if (xRank == 2) { // M*V -> V
if (std::is_same_v<XT, YT>) {
@@ -225,9 +324,9 @@ static inline void DoMatmul(
// TODO: call BLAS-2 ZGEMV(x,y)
}
}
- MatrixTimesVector<RCAT, RKIND, XT, YT>(
+ MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), extent[0], n,
- x.OffsetElement<XT>(), y.OffsetElement<YT>());
+ x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
return;
} else { // V*M -> V
if (std::is_same_v<XT, YT>) {
@@ -241,9 +340,9 @@ static inline void DoMatmul(
// TODO: call BLAS-2 ZGEMV(y,x)
}
}
- VectorTimesMatrix<RCAT, RKIND, XT, YT>(
+ VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>(
result.template OffsetElement<WriteResult>(), n, extent[0],
- x.OffsetElement<XT>(), y.OffsetElement<YT>());
+ x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride);
return;
}
}
diff --git a/flang/unittests/Runtime/Matmul.cpp b/flang/unittests/Runtime/Matmul.cpp
index 30ce3d8a88825f..1d6c5ccc609b42 100644
--- a/flang/unittests/Runtime/Matmul.cpp
+++ b/flang/unittests/Runtime/Matmul.cpp
@@ -27,6 +27,16 @@ TEST(Matmul, Basic) {
std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})};
auto v{MakeArray<TypeCategory::Integer, 8>(
std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
+
+ // X2   0  2  4   Y2  -1 -1
+ //      1  3  5        6  9
+ //     -1 -1 -1        7 10
+ //                     8 11
+ auto x2{MakeArray<TypeCategory::Integer, 4>(std::vector<int>{3, 3},
+ std::vector<std::int32_t>{0, 1, -1, 2, 3, -1, 4, 5})};
+ auto y2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{4, 2},
+ std::vector<std::int16_t>{-1, 6, 7, 8, -1, 9, 10, 11})};
+
StaticDescriptor<2, true> statDesc;
Descriptor &result{statDesc.descriptor()};
@@ -73,6 +83,98 @@ TEST(Matmul, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
result.Destroy();
+ // Test non-contiguous sections.
+ static constexpr int sectionRank{2};
+ StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
+ Descriptor &sectionX2{sectionStaticDescriptorX2.descriptor()};
+ sectionX2.Establish(x2->type(), x2->ElementBytes(),
+ /*p=*/nullptr, /*rank=*/sectionRank);
+ static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{2, 3};
+ // Section of X2:
+ // +--------+
+ // | 0 2 4|
+ // | 1 3 5|
+ // +--------+
+ // -1 -1 -1
+ const auto errorX2{CFI_section(
+ &sectionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)};
+ ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2;
+
+ StaticDescriptor<sectionRank> sectionStaticDescriptorY2;
+ Descriptor &sectionY2{sectionStaticDescriptorY2.descriptor()};
+ sectionY2.Establish(y2->type(), y2->ElementBytes(),
+ /*p=*/nullptr, /*rank=*/sectionRank);
+ static const SubscriptValue lowersY2[]{2, 1};
+ // Section of Y2:
+ // -1 -1
+ // +-----+
+ // | 6 9|
+ // | 7 10|
+ // | 8 11|
+ // +-----+
+ const auto errorY2{CFI_section(&sectionY2.raw(), &y2->raw(), lowersY2,
+ /*uppers=*/nullptr, /*strides=*/nullptr)};
+ ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2;
+
+ RTNAME(Matmul)(result, sectionX2, *y, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
+ RTNAME(Matmul)(result, *x, sectionY2, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
+ RTNAME(Matmul)(result, sectionX2, sectionY2, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
+ RTNAME(Matmul)(result, *v, sectionX2, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14);
+ result.Destroy();
+
+ RTNAME(Matmul)(result, sectionY2, *v, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+ result.Destroy();
+
// X F F T Y F T
// F T T F T
// F F
diff --git a/flang/unittests/Runtime/MatmulTranspose.cpp b/flang/unittests/Runtime/MatmulTranspose.cpp
index 83db1328963a69..2362887c414ecc 100644
--- a/flang/unittests/Runtime/MatmulTranspose.cpp
+++ b/flang/unittests/Runtime/MatmulTranspose.cpp
@@ -32,6 +32,17 @@ TEST(MatmulTranspose, Basic) {
std::vector<std::int16_t>{0, 0, 0, 1, 1, 0, 1, 1})};
auto v{MakeArray<TypeCategory::Integer, 8>(
std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})};
+ // X2   0  1   Y2  -1 -1   Z2   6  7  8
+ //      2  3        6  9        9 10 11
+ //      4  5        7 10       -1 -1 -1
+ //     -1 -1        8 11
+ auto x2{MakeArray<TypeCategory::Integer, 4>(std::vector<int>{4, 2},
+ std::vector<std::int32_t>{0, 2, 4, -1, 1, 3, 5, -1})};
+ auto y2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{4, 2},
+ std::vector<std::int16_t>{-1, 6, 7, 8, -1, 9, 10, 11})};
+ auto z2{MakeArray<TypeCategory::Integer, 2>(std::vector<int>{3, 3},
+ std::vector<std::int16_t>{6, 9, -1, 7, 10, -1, 8, 11, -1})};
+
StaticDescriptor<2, true> statDesc;
Descriptor &result{statDesc.descriptor()};
@@ -89,6 +100,104 @@ TEST(MatmulTranspose, Basic) {
EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int16_t>(11), 19);
result.Destroy();
+ // Test non-contiguous sections.
+ static constexpr int sectionRank{2};
+ StaticDescriptor<sectionRank> sectionStaticDescriptorX2;
+ Descriptor &sectionX2{sectionStaticDescriptorX2.descriptor()};
+ sectionX2.Establish(x2->type(), x2->ElementBytes(),
+ /*p=*/nullptr, /*rank=*/sectionRank);
+ static const SubscriptValue lowersX2[]{1, 1}, uppersX2[]{3, 2};
+ // Section of X2:
+ // +-----+
+ // | 0 1|
+ // | 2 3|
+ // | 4 5|
+ // +-----+
+ // -1 -1
+ const auto errorX2{CFI_section(
+ &sectionX2.raw(), &x2->raw(), lowersX2, uppersX2, /*strides=*/nullptr)};
+ ASSERT_EQ(errorX2, 0) << "CFI_section failed for X2: " << errorX2;
+
+ StaticDescriptor<sectionRank> sectionStaticDescriptorY2;
+ Descriptor &sectionY2{sectionStaticDescriptorY2.descriptor()};
+ sectionY2.Establish(y2->type(), y2->ElementBytes(),
+ /*p=*/nullptr, /*rank=*/sectionRank);
+ static const SubscriptValue lowersY2[]{2, 1};
+ // Section of Y2:
+ // -1 -1
+ // +-----+
+ // | 6 9|
+ // | 7 10|
+ // | 8 11|
+ // +-----+
+ const auto errorY2{CFI_section(&sectionY2.raw(), &y2->raw(), lowersY2,
+ /*uppers=*/nullptr, /*strides=*/nullptr)};
+ ASSERT_EQ(errorY2, 0) << "CFI_section failed for Y2: " << errorY2;
+
+ StaticDescriptor<sectionRank> sectionStaticDescriptorZ2;
+ Descriptor &sectionZ2{sectionStaticDescriptorZ2.descriptor()};
+ sectionZ2.Establish(z2->type(), z2->ElementBytes(),
+ /*p=*/nullptr, /*rank=*/sectionRank);
+ static const SubscriptValue lowersZ2[]{1, 1}, uppersZ2[]{2, 3};
+ // Section of Z2:
+ // +--------+
+ // | 6 7 8|
+ // | 9 10 11|
+ // +--------+
+ // -1 -1 -1
+ const auto errorZ2{CFI_section(
+ &sectionZ2.raw(), &z2->raw(), lowersZ2, uppersZ2, /*strides=*/nullptr)};
+ ASSERT_EQ(errorZ2, 0) << "CFI_section failed for Z2: " << errorZ2;
+
+ RTNAME(MatmulTranspose)(result, sectionX2, *y, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
+ RTNAME(MatmulTranspose)(result, *x, sectionY2, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
+ RTNAME(MatmulTranspose)(result, sectionX2, sectionY2, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 2);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 2);
+ EXPECT_EQ(result.GetDimension(1).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(1).Extent(), 2);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94);
+ result.Destroy();
+
+ RTNAME(MatmulTranspose)(result, sectionZ2, *v, __FILE__, __LINE__);
+ ASSERT_EQ(result.rank(), 1);
+ EXPECT_EQ(result.GetDimension(0).LowerBound(), 1);
+ EXPECT_EQ(result.GetDimension(0).Extent(), 3);
+ ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8}));
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27);
+ EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30);
+ result.Destroy();
+
// X F F Y F T V T F T
// T F F T
// T T F F
More information about the flang-commits
mailing list