[flang-commits] [flang] [flang] Fold MATMUL() (PR #72176)
via flang-commits
flang-commits at lists.llvm.org
Mon Nov 13 16:31:11 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
Implements constant folding of matrix multiplication (MATMUL) for all four accepted type categories: INTEGER, REAL, COMPLEX, and LOGICAL.
---
Full diff: https://github.com/llvm/llvm-project/pull/72176.diff
7 Files Affected:
- (modified) flang/lib/Evaluate/fold-complex.cpp (+3-1)
- (modified) flang/lib/Evaluate/fold-integer.cpp (+3-1)
- (modified) flang/lib/Evaluate/fold-logical.cpp (+3-1)
- (added) flang/lib/Evaluate/fold-matmul.h (+103)
- (modified) flang/lib/Evaluate/fold-real.cpp (+3-1)
- (modified) flang/lib/Evaluate/fold-reduction.h (+2-2)
- (added) flang/test/Evaluate/fold-matmul.f90 (+41)
``````````diff
diff --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp
index e40e3a37df14948..3260f82ffe8d734 100644
--- a/flang/lib/Evaluate/fold-complex.cpp
+++ b/flang/lib/Evaluate/fold-complex.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "fold-implementation.h"
+#include "fold-matmul.h"
#include "fold-reduction.h"
namespace Fortran::evaluate {
@@ -64,13 +65,14 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
}
} else if (name == "dot_product") {
return FoldDotProduct<T>(context, std::move(funcRef));
+ } else if (name == "matmul") {
+ return FoldMatmul(context, std::move(funcRef));
} else if (name == "product") {
auto one{Scalar<Part>::FromInteger(value::Integer<8>{1}).value};
return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{one});
} else if (name == "sum") {
return FoldSum<T>(context, std::move(funcRef));
}
- // TODO: matmul
return Expr<T>{std::move(funcRef)};
}
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index dedfc20a491cd88..2882369105f6626 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "fold-implementation.h"
+#include "fold-matmul.h"
#include "fold-reduction.h"
#include "flang/Evaluate/check-expression.h"
@@ -1042,6 +1043,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> {
return fptr(static_cast<int>(places.ToInt64()));
}));
+ } else if (name == "matmul") {
+ return FoldMatmul(context, std::move(funcRef));
} else if (name == "max") {
return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
} else if (name == "max0" || name == "max1") {
@@ -1279,7 +1282,6 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else if (name == "ubound") {
return UBOUND(context, std::move(funcRef));
}
- // TODO: dot_product, matmul, sign
return Expr<T>{std::move(funcRef)};
}
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index bfedc32a33a8bad..82a5cb20db9e409 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "fold-implementation.h"
+#include "fold-matmul.h"
#include "fold-reduction.h"
#include "flang/Evaluate/check-expression.h"
#include "flang/Runtime/magic-numbers.h"
@@ -231,6 +232,8 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
return Fold(context, ConvertToType<T>(std::move(*expr)));
}
+ } else if (name == "matmul") {
+ return FoldMatmul(context, std::move(funcRef));
} else if (name == "out_of_range") {
if (Expr<SomeType> * cx{UnwrapExpr<Expr<SomeType>>(args[0])}) {
auto restorer{context.messages().DiscardMessages()};
@@ -367,7 +370,6 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
name == "__builtin_ieee_support_underflow_control") {
return Expr<T>{true};
}
- // TODO: logical, matmul, parity
return Expr<T>{std::move(funcRef)};
}
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
new file mode 100644
index 000000000000000..27b6db1fd8bf025
--- /dev/null
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -0,0 +1,123 @@
+//===-- lib/Evaluate/fold-matmul.h ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_
+#define FORTRAN_EVALUATE_FOLD_MATMUL_H_
+
+#include "fold-implementation.h"
+#include <cstdint>
+
+namespace Fortran::evaluate {
+
+// Folds MATMUL(MATRIX_A, MATRIX_B) when both arguments fold to constants of
+// the result type T (INTEGER, REAL, COMPLEX, or LOGICAL).  Supported shapes
+// are (n,m)x(m,k) -> (n,k), (m)x(m,k) -> (k), and (n,m)x(m) -> (n); other
+// ranks are rejected by CHECK.  Returns the original reference when either
+// argument is not constant, an invalid intrinsic (after an error message)
+// when the inner extents differ, and otherwise a Constant<T> result.  A
+// warning is emitted if any intermediate product or sum overflows.
+template <typename T>
+static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
+  using Element = typename Constant<T>::Element;
+  // Bound by reference so the arguments are folded in place rather than
+  // deep-copied.
+  auto &args{funcRef.arguments()};
+  CHECK(args.size() == 2);
+  Folder<T> folder{context};
+  Constant<T> *ma{folder.Folding(args[0])};
+  Constant<T> *mb{folder.Folding(args[1])};
+  if (!ma || !mb) {
+    return Expr<T>{std::move(funcRef)};
+  }
+  CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 &&
+      mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2));
+  ConstantSubscript commonExtent{ma->shape().back()};
+  if (mb->shape().front() != commonExtent) {
+    // %jd + std::intmax_t is portable; %zd assumed a size_t-width subscript.
+    context.messages().Say(
+        "Arguments to MATMUL have distinct extents %jd and %jd on their last and first dimensions"_err_en_US,
+        static_cast<std::intmax_t>(commonExtent),
+        static_cast<std::intmax_t>(mb->shape().front()));
+    return MakeInvalidIntrinsic(std::move(funcRef));
+  }
+  ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]};
+  ConstantSubscript columns{mb->Rank() == 1 ? 1 : mb->shape()[1]};
+  std::vector<Element> elements;
+  elements.reserve(rows * columns);
+  bool overflow{false};
+  [[maybe_unused]] const auto &rounding{
+      context.targetCharacteristics().roundingMode()};
+  // result(j,k) = SUM(A(j,:) * B(:,k)); rows vary fastest, so elements are
+  // produced in Fortran array element (column-major) order.
+  for (ConstantSubscript ci{0}; ci < columns; ++ci) {
+    for (ConstantSubscript ri{0}; ri < rows; ++ri) {
+      // Start at row ri of A and column ci of B; the inner loop then steps
+      // along A's last dimension and B's first dimension in lockstep.
+      ConstantSubscripts aAt{ma->lbounds()};
+      if (ma->Rank() == 2) {
+        aAt[0] += ri;
+      }
+      ConstantSubscripts bAt{mb->lbounds()};
+      if (mb->Rank() == 2) {
+        bAt[1] += ci;
+      }
+      Element sum{};
+      [[maybe_unused]] Element correction{};
+      for (ConstantSubscript j{0}; j < commonExtent; ++j) {
+        Element aElt{ma->At(aAt)};
+        Element bElt{mb->At(bAt)};
+        if constexpr (T::category == TypeCategory::Real ||
+            T::category == TypeCategory::Complex) {
+          // Kahan summation.  `correction` is (added - sum) - next, the
+          // rounding error of the previous addition, so it must be
+          // subtracted from the next term; adding it would double the
+          // error instead of cancelling it.
+          auto product{aElt.Multiply(bElt, rounding)};
+          overflow |= product.flags.test(RealFlag::Overflow);
+          auto next{product.value.Subtract(correction, rounding)};
+          overflow |= next.flags.test(RealFlag::Overflow);
+          auto added{sum.Add(next.value, rounding)};
+          overflow |= added.flags.test(RealFlag::Overflow);
+          correction = added.value.Subtract(sum, rounding)
+                           .value.Subtract(next.value, rounding)
+                           .value;
+          sum = std::move(added.value);
+        } else if constexpr (T::category == TypeCategory::Integer) {
+          auto product{aElt.MultiplySigned(bElt)};
+          overflow |= product.SignedMultiplicationOverflowed();
+          auto added{sum.AddSigned(product.lower)};
+          overflow |= added.overflow;
+          sum = std::move(added.value);
+        } else {
+          static_assert(T::category == TypeCategory::Logical);
+          // LOGICAL result: ANY(A(j,:) .AND. B(:,k)).
+          sum = sum.OR(aElt.AND(bElt));
+        }
+        ++aAt.back();
+        ++bAt.front();
+      }
+      elements.push_back(sum);
+    }
+  }
+  if (overflow) {
+    context.messages().Say(
+        "MATMUL of %s data overflowed during computation"_warn_en_US,
+        T::AsFortran());
+  }
+  // A rank-1 argument contributes no dimension to the result's shape.
+  ConstantSubscripts shape;
+  if (ma->Rank() == 2) {
+    shape.push_back(rows);
+  }
+  if (mb->Rank() == 2) {
+    shape.push_back(columns);
+  }
+  return Expr<T>{Constant<T>{std::move(elements), std::move(shape)}};
+}
+} // namespace Fortran::evaluate
+#endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 6bcc3ec73982157..6ae069df5d7a425 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "fold-implementation.h"
+#include "fold-matmul.h"
#include "fold-reduction.h"
namespace Fortran::evaluate {
@@ -269,6 +270,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
}
return result.value;
}));
+ } else if (name == "matmul") {
+ return FoldMatmul(context, std::move(funcRef));
} else if (name == "max") {
return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
} else if (name == "maxval") {
@@ -446,7 +449,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return result.value;
}));
}
- // TODO: matmul
return Expr<T>{std::move(funcRef)};
}
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 0dd55124e6a512e..60c757dc3f4fa8e 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -43,7 +43,7 @@ static Expr<T> FoldDotProduct(
Expr<T> products{Fold(
context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
- Element correction; // Use Kahan summation for greater precision.
+ Element correction{}; // Use Kahan summation for greater precision.
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
auto next{correction.Add(x, rounding)};
@@ -80,7 +80,7 @@ static Expr<T> FoldDotProduct(
Expr<T> products{
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
- Element correction; // Use Kahan summation for greater precision.
+ Element correction{}; // Use Kahan summation for greater precision.
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
auto next{correction.Add(x, rounding)};
diff --git a/flang/test/Evaluate/fold-matmul.f90 b/flang/test/Evaluate/fold-matmul.f90
new file mode 100644
index 000000000000000..dce90197e1f1fdd
--- /dev/null
+++ b/flang/test/Evaluate/fold-matmul.f90
@@ -0,0 +1,41 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of MATMUL() for INTEGER, REAL, COMPLEX, and LOGICAL arguments
+module m
+ integer, parameter :: ia(2,3) = reshape([1, 2, 2, 3, 3, 4], shape(ia)) ! rows (1,2,3) and (2,3,4)
+ integer, parameter :: ib(3,2) = reshape([1, 2, 3, 2, 3, 4], shape(ib)) ! columns (1,2,3) and (2,3,4)
+ integer, parameter :: ix(*) = [1, 2]
+ integer, parameter :: iy(*) = [1, 2, 3]
+ integer, parameter :: iab(*,*) = matmul(ia, ib) ! matrix x matrix: 2x2 result
+ integer, parameter :: ixa(*) = matmul(ix, ia) ! vector x matrix: rank-1 result of size 3
+ integer, parameter :: iay(*) = matmul(ia, iy) ! matrix x vector: rank-1 result of size 2
+ logical, parameter :: test_iab = all([iab] == [14, 20, 20, 29]) ! [iab] is in column-major order
+ logical, parameter :: test_ixa = all(ixa == [5, 8, 11])
+ logical, parameter :: test_iay = all(iay == [14, 20])
+
+ real, parameter :: ra(*,*) = ia ! small integral values: every product and sum is exact
+ real, parameter :: rb(*,*) = ib
+ real, parameter :: rx(*) = ix
+ real, parameter :: ry(*) = iy
+ real, parameter :: rab(*,*) = matmul(ra, rb)
+ real, parameter :: rxa(*) = matmul(rx, ra)
+ real, parameter :: ray(*) = matmul(ra, ry)
+ logical, parameter :: test_rab = all(rab == iab)
+ logical, parameter :: test_rxa = all(rxa == ixa)
+ logical, parameter :: test_ray = all(ray == iay)
+
+ complex, parameter :: za(*,*) = cmplx(ra, -1.) ! every element gets imaginary part -1
+ complex, parameter :: zb(*,*) = cmplx(rb, -1.)
+ complex, parameter :: zx(*) = cmplx(rx, -1.)
+ complex, parameter :: zy(*) = cmplx(ry, -1.)
+ complex, parameter :: zab(*,*) = matmul(za, zb) ! MATMUL does not conjugate (unlike DOT_PRODUCT)
+ complex, parameter :: zxa(*) = matmul(zx, za)
+ complex, parameter :: zay(*) = matmul(za, zy)
+ logical, parameter :: test_zab = all([zab] == [(11,-12),(17,-15),(17,-15),(26,-18)])
+ logical, parameter :: test_zxa = all(zxa == [(3,-6),(6,-8),(9,-10)])
+ logical, parameter :: test_zay = all(zay == [(11,-12),(17,-15)])
+
+ logical, parameter :: la(16, 4) = reshape([((iand(shiftr(j,k),1)/=0, j=0,15), k=0,3)], shape(la))
+ logical, parameter :: lb(4, 16) = transpose(la) ! lb(k+1,j+1) is bit k of j
+ logical, parameter :: lab(16, 16) = matmul(la, lb) ! expect lab(j+1,k+1) == (iand(j,k) /= 0)
+ logical, parameter :: test_lab = all([lab] .eqv. [((iand(k,j)/=0, k=0,15), j=0,15)])
+end
``````````
</details>
https://github.com/llvm/llvm-project/pull/72176
More information about the flang-commits mailing list