[flang-commits] [flang] d48236a - [flang] Fold SPREAD

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Thu Oct 28 14:10:24 PDT 2021


Author: peter klausler
Date: 2021-10-28T14:10:16-07:00
New Revision: d48236a51c5a9cf372f8f633e538f0e1784a16d4

URL: https://github.com/llvm/llvm-project/commit/d48236a51c5a9cf372f8f633e538f0e1784a16d4
DIFF: https://github.com/llvm/llvm-project/commit/d48236a51c5a9cf372f8f633e538f0e1784a16d4.diff

LOG: [flang] Fold SPREAD

Implements constant folding of the transformational intrinsic
function SPREAD().

Differential Revision: https://reviews.llvm.org/D112739

Added: 
    flang/test/Evaluate/errors01.f90
    flang/test/Evaluate/fold-spread.f90

Modified: 
    flang/lib/Evaluate/fold-implementation.h

Removed: 
    flang/test/Evaluate/folding19.f90


################################################################################
diff  --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index c61637263a63..4f885413b396 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -65,6 +65,7 @@ template <typename T> class Folder {
   Expr<T> EOSHIFT(FunctionRef<T> &&);
   Expr<T> PACK(FunctionRef<T> &&);
   Expr<T> RESHAPE(FunctionRef<T> &&);
+  Expr<T> SPREAD(FunctionRef<T> &&);
   Expr<T> TRANSPOSE(FunctionRef<T> &&);
   Expr<T> UNPACK(FunctionRef<T> &&);
 
@@ -855,6 +856,51 @@ template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
   return MakeInvalidIntrinsic(std::move(funcRef));
 }
 
+template <typename T> Expr<T> Folder<T>::SPREAD(FunctionRef<T> &&funcRef) {
+  auto args{funcRef.arguments()}; // Folds SPREAD(SOURCE, DIM, NCOPIES)
+  CHECK(args.size() == 3);
+  const Constant<T> *source{UnwrapConstantValue<T>(args[0])};
+  auto dim{GetInt64Arg(args[1])};
+  auto ncopies{GetInt64Arg(args[2])};
+  if (!source || !dim) { // SOURCE or DIM not constant: keep the call
+    return Expr<T>{std::move(funcRef)};
+  }
+  int sourceRank{source->Rank()};
+  if (sourceRank >= common::maxRank) {
+    context_.messages().Say(
+        "SOURCE= argument to SPREAD has rank %d but must have rank less than %d"_err_en_US,
+        sourceRank, common::maxRank);
+  } else if (*dim < 1 || *dim > sourceRank + 1) {
+    context_.messages().Say( // DIM is 64-bit; print it via %jd
+        "DIM=%jd argument to SPREAD must be between 1 and %d"_err_en_US,
+        static_cast<std::intmax_t>(*dim), sourceRank + 1);
+  } else if (!ncopies) {
+    return Expr<T>{std::move(funcRef)};
+  } else {
+    if (*ncopies < 0) { // negative NCOPIES means zero copies
+      ncopies = 0;
+    }
+    // TODO: Consider moving this implementation (after the user error
+    // checks), along with other transformational intrinsics, into
+    // constant.h (or a new header) so that the transformationals
+    // are available for all Constant<>s without needing to be packaged
+    // as references to intrinsic functions for folding.
+    ConstantSubscripts shape{source->shape()};
+    shape.insert(shape.begin() + *dim - 1, *ncopies);
+    Constant<T> spread{source->Reshape(std::move(shape))};
+    std::vector<int> dimOrder; // result dims, fastest-varying first
+    for (int j{0}; j < sourceRank; ++j) { // SOURCE dims, in their order
+      dimOrder.push_back(j < *dim - 1 ? j : j + 1); // skip the new dim
+    }
+    dimOrder.push_back(*dim - 1); // new dim slowest: one SOURCE copy per step
+    ConstantSubscripts at{spread.lbounds()}; // all 1
+    spread.CopyFrom(*source, TotalElementCount(spread.shape()), at, &dimOrder);
+    return Expr<T>{std::move(spread)};
+  }
+  // Invalid, prevent re-folding
+  return MakeInvalidIntrinsic(std::move(funcRef));
+}
+
 template <typename T> Expr<T> Folder<T>::TRANSPOSE(FunctionRef<T> &&funcRef) {
   auto args{funcRef.arguments()};
   CHECK(args.size() == 1);
@@ -1017,12 +1063,13 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
       return Folder<T>{context}.PACK(std::move(funcRef));
     } else if (name == "reshape") {
       return Folder<T>{context}.RESHAPE(std::move(funcRef));
+    } else if (name == "spread") {
+      return Folder<T>{context}.SPREAD(std::move(funcRef));
     } else if (name == "transpose") {
       return Folder<T>{context}.TRANSPOSE(std::move(funcRef));
     } else if (name == "unpack") {
       return Folder<T>{context}.UNPACK(std::move(funcRef));
     }
-    // TODO: spread
     // TODO: extends_type_of, same_type_as
     if constexpr (!std::is_same_v<T, SomeDerived>) {
       return FoldIntrinsicFunction(context, std::move(funcRef));

diff  --git a/flang/test/Evaluate/folding19.f90 b/flang/test/Evaluate/errors01.f90
similarity index 88%
rename from flang/test/Evaluate/folding19.f90
rename to flang/test/Evaluate/errors01.f90
index 32d4be7e01db..2ba5444da931 100644
--- a/flang/test/Evaluate/folding19.f90
+++ b/flang/test/Evaluate/errors01.f90
@@ -90,4 +90,14 @@ subroutine s8
     !CHECK: error: SHIFT=65 count for shiftl is greater than 64
     integer(8), parameter :: bad6 = shiftl(1_8, 65)
   end subroutine
+  subroutine s9 ! SPREAD user errors: SOURCE= rank and DIM= range
+    integer, parameter :: rank15(1,1,1,1,1,1,1,1,1,1,1,1,1,1,1) = 1
+    !CHECK: error: SOURCE= argument to SPREAD has rank 15 but must have rank less than 15
+    integer, parameter :: bad1 = spread(rank15, 1, 1) ! result rank would be 16
+    integer, parameter :: matrix(2, 2) = reshape([1, 2, 3, 4], [2, 2])
+    !CHECK: error: DIM=0 argument to SPREAD must be between 1 and 3
+    integer, parameter :: bad2 = spread(matrix, 0, 1) ! DIM= below 1
+    !CHECK: error: DIM=4 argument to SPREAD must be between 1 and 3
+    integer, parameter :: bad3 = spread(matrix, 4, 1) ! DIM= above rank(matrix)+1
+  end subroutine
 end module

diff  --git a/flang/test/Evaluate/fold-spread.f90 b/flang/test/Evaluate/fold-spread.f90
new file mode 100644
index 000000000000..127de8fbbe6a
--- /dev/null
+++ b/flang/test/Evaluate/fold-spread.f90
@@ -0,0 +1,13 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of SPREAD
+module m1 ! SPREAD folding cases for SOURCE ranks 0, 1 and 2
+  logical, parameter :: test_empty = size(spread(1, 1, 0)) == 0
+  logical, parameter :: test_stov = all(spread(1, 1, 2) == [1, 1])
+  logical, parameter :: test_vtom1 = all(spread([1, 2], 1, 3) == reshape([1, 1, 1, 2, 2, 2], [3, 2]))
+  logical, parameter :: test_vtom2 = all(spread([1, 2], 2, 3) == reshape([1, 2, 1, 2, 1, 2], [2, 3]))
+  logical, parameter :: test_mtoa2 = all(spread(reshape([1,2,3,4],[2,2]), 2, 2) == reshape([1,2,1,2,3,4,3,4],[2,2,2]))
+  logical, parameter :: test_log1 = all(all(spread([.false., .true.], 1, 2), dim=2) .eqv. [.false., .false.])
+  logical, parameter :: test_log2 = all(all(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.])
+  logical, parameter :: test_log3 = all(any(spread([.false., .true.], 1, 2), dim=2) .eqv. [.true., .true.])
+  logical, parameter :: test_log4 = all(any(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.])
+end module


        


More information about the flang-commits mailing list