[flang-commits] [flang] [flang] Fold MATMUL() (PR #72176)

via flang-commits flang-commits at lists.llvm.org
Mon Nov 13 16:31:11 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Peter Klausler (klausler)

<details>
<summary>Changes</summary>

Implements constant folding of MATMUL() for all four accepted type categories: INTEGER, REAL, COMPLEX, and LOGICAL.

---
Full diff: https://github.com/llvm/llvm-project/pull/72176.diff


7 Files Affected:

- (modified) flang/lib/Evaluate/fold-complex.cpp (+3-1) 
- (modified) flang/lib/Evaluate/fold-integer.cpp (+3-1) 
- (modified) flang/lib/Evaluate/fold-logical.cpp (+3-1) 
- (added) flang/lib/Evaluate/fold-matmul.h (+103) 
- (modified) flang/lib/Evaluate/fold-real.cpp (+3-1) 
- (modified) flang/lib/Evaluate/fold-reduction.h (+2-2) 
- (added) flang/test/Evaluate/fold-matmul.f90 (+41) 


``````````diff
diff --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp
index e40e3a37df14948..3260f82ffe8d734 100644
--- a/flang/lib/Evaluate/fold-complex.cpp
+++ b/flang/lib/Evaluate/fold-complex.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "fold-implementation.h"
+#include "fold-matmul.h"
 #include "fold-reduction.h"
 
 namespace Fortran::evaluate {
@@ -64,13 +65,14 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
     }
   } else if (name == "dot_product") {
     return FoldDotProduct<T>(context, std::move(funcRef));
+  } else if (name == "matmul") {
+    return FoldMatmul(context, std::move(funcRef));
   } else if (name == "product") {
     auto one{Scalar<Part>::FromInteger(value::Integer<8>{1}).value};
     return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{one});
   } else if (name == "sum") {
     return FoldSum<T>(context, std::move(funcRef));
   }
-  // TODO: matmul
   return Expr<T>{std::move(funcRef)};
 }
 
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index dedfc20a491cd88..2882369105f6626 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "fold-implementation.h"
+#include "fold-matmul.h"
 #include "fold-reduction.h"
 #include "flang/Evaluate/check-expression.h"
 
@@ -1042,6 +1043,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
         ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> {
           return fptr(static_cast<int>(places.ToInt64()));
         }));
+  } else if (name == "matmul") {
+    return FoldMatmul(context, std::move(funcRef));
   } else if (name == "max") {
     return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
   } else if (name == "max0" || name == "max1") {
@@ -1279,7 +1282,6 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
   } else if (name == "ubound") {
     return UBOUND(context, std::move(funcRef));
   }
-  // TODO: dot_product, matmul, sign
   return Expr<T>{std::move(funcRef)};
 }
 
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index bfedc32a33a8bad..82a5cb20db9e409 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "fold-implementation.h"
+#include "fold-matmul.h"
 #include "fold-reduction.h"
 #include "flang/Evaluate/check-expression.h"
 #include "flang/Runtime/magic-numbers.h"
@@ -231,6 +232,8 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
     if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
       return Fold(context, ConvertToType<T>(std::move(*expr)));
     }
+  } else if (name == "matmul") {
+    return FoldMatmul(context, std::move(funcRef));
   } else if (name == "out_of_range") {
     if (Expr<SomeType> * cx{UnwrapExpr<Expr<SomeType>>(args[0])}) {
       auto restorer{context.messages().DiscardMessages()};
@@ -367,7 +370,6 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
       name == "__builtin_ieee_support_underflow_control") {
     return Expr<T>{true};
   }
-  // TODO: logical, matmul, parity
   return Expr<T>{std::move(funcRef)};
 }
 
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
new file mode 100644
index 000000000000000..27b6db1fd8bf025
--- /dev/null
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -0,0 +1,103 @@
+//===-- lib/Evaluate/fold-matmul.h ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_
+#define FORTRAN_EVALUATE_FOLD_MATMUL_H_
+
+#include "fold-implementation.h"
+
+namespace Fortran::evaluate {
+
+template <typename T>
+static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
+  using Element = typename Constant<T>::Element;
+  auto args{funcRef.arguments()};
+  CHECK(args.size() == 2);
+  Folder<T> folder{context};
+  Constant<T> *ma{folder.Folding(args[0])};
+  Constant<T> *mb{folder.Folding(args[1])};
+  if (!ma || !mb) {
+    return Expr<T>{std::move(funcRef)};
+  }
+  CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 &&
+      mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2));
+  ConstantSubscript commonExtent{ma->shape().back()};
+  if (mb->shape().front() != commonExtent) {
+    context.messages().Say(
+        "Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US,
+        commonExtent, mb->shape().front());
+    return MakeInvalidIntrinsic(std::move(funcRef));
+  }
+  ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]};
+  ConstantSubscript columns{mb->Rank() == 1 ? 1 : mb->shape()[1]};
+  std::vector<Element> elements;
+  elements.reserve(rows * columns);
+  bool overflow{false};
+  [[maybe_unused]] const auto &rounding{
+      context.targetCharacteristics().roundingMode()};
+  // result(j,k) = SUM(A(j,:) * B(:,k)), built in column-major order
+  for (ConstantSubscript ci{0}; ci < columns; ++ci) {
+    for (ConstantSubscript ri{0}; ri < rows; ++ri) {
+      ConstantSubscripts aAt{ma->lbounds()};
+      if (ma->Rank() == 2) {
+        aAt[0] += ri;
+      }
+      ConstantSubscripts bAt{mb->lbounds()};
+      if (mb->Rank() == 2) {
+        bAt[1] += ci;
+      }
+      Element sum{};
+      [[maybe_unused]] Element correction{};
+      for (ConstantSubscript j{0}; j < commonExtent; ++j) {
+        Element aElt{ma->At(aAt)};
+        Element bElt{mb->At(bAt)};
+        if constexpr (T::category == TypeCategory::Real ||
+            T::category == TypeCategory::Complex) {
+          // Kahan: next = x - c; s' = s + next; c = (s' - s) - next
+          auto product{aElt.Multiply(bElt, rounding)};
+          overflow |= product.flags.test(RealFlag::Overflow);
+          auto next{product.value.Subtract(correction, rounding)};
+          overflow |= next.flags.test(RealFlag::Overflow);
+          auto added{sum.Add(next.value, rounding)};
+          overflow |= added.flags.test(RealFlag::Overflow);
+          correction = added.value.Subtract(sum, rounding)
+                           .value.Subtract(next.value, rounding)
+                           .value;
+          sum = std::move(added.value);
+        } else if constexpr (T::category == TypeCategory::Integer) {
+          auto product{aElt.MultiplySigned(bElt)};
+          overflow |= product.SignedMultiplicationOverflowed();
+          auto added{sum.AddSigned(product.lower)};
+          overflow |= added.overflow;
+          sum = std::move(added.value);
+        } else {
+          static_assert(T::category == TypeCategory::Logical);
+          sum = sum.OR(aElt.AND(bElt));
+        }
+        ++aAt.back();
+        ++bAt.front();
+      }
+      elements.push_back(sum);
+    }
+  }
+  if (overflow) {
+    context.messages().Say(
+        "MATMUL of %s data overflowed during computation"_warn_en_US,
+        T::AsFortran());
+  }
+  ConstantSubscripts shape;
+  if (ma->Rank() == 2) {
+    shape.push_back(rows);
+  }
+  if (mb->Rank() == 2) {
+    shape.push_back(columns);
+  }
+  return Expr<T>{Constant<T>{std::move(elements), std::move(shape)}};
+}
+} // namespace Fortran::evaluate
+#endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 6bcc3ec73982157..6ae069df5d7a425 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "fold-implementation.h"
+#include "fold-matmul.h"
 #include "fold-reduction.h"
 
 namespace Fortran::evaluate {
@@ -269,6 +270,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
               }
               return result.value;
             }));
+  } else if (name == "matmul") {
+    return FoldMatmul(context, std::move(funcRef));
   } else if (name == "max") {
     return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
   } else if (name == "maxval") {
@@ -446,7 +449,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
           return result.value;
         }));
   }
-  // TODO: matmul
   return Expr<T>{std::move(funcRef)};
 }
 
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 0dd55124e6a512e..60c757dc3f4fa8e 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -43,7 +43,7 @@ static Expr<T> FoldDotProduct(
       Expr<T> products{Fold(
           context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
-      Element correction; // Use Kahan summation for greater precision.
+      Element correction{}; // Use Kahan summation for greater precision.
       const auto &rounding{context.targetCharacteristics().roundingMode()};
       for (const Element &x : cProducts.values()) {
         auto next{correction.Add(x, rounding)};
@@ -80,7 +80,7 @@ static Expr<T> FoldDotProduct(
       Expr<T> products{
           Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
       Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
-      Element correction; // Use Kahan summation for greater precision.
+      Element correction{}; // Use Kahan summation for greater precision.
       const auto &rounding{context.targetCharacteristics().roundingMode()};
       for (const Element &x : cProducts.values()) {
         auto next{correction.Add(x, rounding)};
diff --git a/flang/test/Evaluate/fold-matmul.f90 b/flang/test/Evaluate/fold-matmul.f90
new file mode 100644
index 000000000000000..dce90197e1f1fdd
--- /dev/null
+++ b/flang/test/Evaluate/fold-matmul.f90
@@ -0,0 +1,41 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of MATMUL()
+module m
+  integer, parameter :: ia(2,3) = reshape([1, 2, 2, 3, 3, 4], shape(ia)) ! rows [1 2 3], [2 3 4]
+  integer, parameter :: ib(3,2) = reshape([1, 2, 3, 2, 3, 4], shape(ib)) ! cols [1 2 3], [2 3 4]
+  integer, parameter :: ix(*) = [1, 2]
+  integer, parameter :: iy(*) = [1, 2, 3]
+  integer, parameter :: iab(*,*) = matmul(ia, ib) ! (2x3)*(3x2) -> 2x2
+  integer, parameter :: ixa(*) = matmul(ix, ia) ! vector*matrix -> extent 3
+  integer, parameter :: iay(*) = matmul(ia, iy) ! matrix*vector -> extent 2
+  logical, parameter :: test_iab = all([iab] == [14, 20, 20, 29]) ! column-major
+  logical, parameter :: test_ixa = all(ixa == [5, 8, 11])
+  logical, parameter :: test_iay = all(iay == [14, 20])
+
+  real, parameter :: ra(*,*) = ia ! same values as the INTEGER cases
+  real, parameter :: rb(*,*) = ib
+  real, parameter :: rx(*) = ix
+  real, parameter :: ry(*) = iy
+  real, parameter :: rab(*,*) = matmul(ra, rb)
+  real, parameter :: rxa(*) = matmul(rx, ra)
+  real, parameter :: ray(*) = matmul(ra, ry)
+  logical, parameter :: test_rab = all(rab == iab)
+  logical, parameter :: test_rxa = all(rxa == ixa)
+  logical, parameter :: test_ray = all(ray == iay)
+
+  complex, parameter :: za(*,*) = cmplx(ra, -1.) ! imaginary parts all -1
+  complex, parameter :: zb(*,*) = cmplx(rb, -1.)
+  complex, parameter :: zx(*) = cmplx(rx, -1.)
+  complex, parameter :: zy(*) = cmplx(ry, -1.)
+  complex, parameter :: zab(*,*) = matmul(za, zb)
+  complex, parameter :: zxa(*) = matmul(zx, za)
+  complex, parameter :: zay(*) = matmul(za, zy)
+  logical, parameter :: test_zab = all([zab] == [(11,-12),(17,-15),(17,-15),(26,-18)])
+  logical, parameter :: test_zxa = all(zxa == [(3,-6),(6,-8),(9,-10)])
+  logical, parameter :: test_zay = all(zay == [(11,-12),(17,-15)])
+
+  logical, parameter :: la(16, 4) = reshape([((iand(shiftr(j,k),1)/=0, j=0,15), k=0,3)], shape(la)) ! bit k of j
+  logical, parameter :: lb(4, 16) = transpose(la)
+  logical, parameter :: lab(16, 16) = matmul(la, lb) ! lab(a+1,b+1) = iand(a,b)/=0
+  logical, parameter :: test_lab = all([lab] .eqv. [((iand(k,j)/=0, k=0,15), j=0,15)])
+end

``````````

</details>


https://github.com/llvm/llvm-project/pull/72176


More information about the flang-commits mailing list