[flang-commits] [flang] d48236a - [flang] Fold SPREAD
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Thu Oct 28 14:10:24 PDT 2021
Author: peter klausler
Date: 2021-10-28T14:10:16-07:00
New Revision: d48236a51c5a9cf372f8f633e538f0e1784a16d4
URL: https://github.com/llvm/llvm-project/commit/d48236a51c5a9cf372f8f633e538f0e1784a16d4
DIFF: https://github.com/llvm/llvm-project/commit/d48236a51c5a9cf372f8f633e538f0e1784a16d4.diff
LOG: [flang] Fold SPREAD
Implements constant folding of the transformational intrinsic
function SPREAD().
Differential Revision: https://reviews.llvm.org/D112739
Added:
flang/test/Evaluate/errors01.f90
flang/test/Evaluate/fold-spread.f90
Modified:
flang/lib/Evaluate/fold-implementation.h
Removed:
flang/test/Evaluate/folding19.f90
################################################################################
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index c61637263a63..4f885413b396 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -65,6 +65,7 @@ template <typename T> class Folder {
Expr<T> EOSHIFT(FunctionRef<T> &&);
Expr<T> PACK(FunctionRef<T> &&);
Expr<T> RESHAPE(FunctionRef<T> &&);
+ Expr<T> SPREAD(FunctionRef<T> &&);
Expr<T> TRANSPOSE(FunctionRef<T> &&);
Expr<T> UNPACK(FunctionRef<T> &&);
@@ -855,6 +856,51 @@ template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
return MakeInvalidIntrinsic(std::move(funcRef));
}
+ // Folds SPREAD(SOURCE, DIM, NCOPIES) when SOURCE is a constant and DIM and
+ // NCOPIES are constant integers.  The result has rank RANK(SOURCE)+1: a new
+ // dimension of extent MAX(NCOPIES, 0) is inserted at position DIM, and every
+ // slice along that dimension equals SOURCE.
+ // Returns the reference unchanged when SOURCE, DIM, or NCOPIES is not
+ // constant.  On a user error (SOURCE rank too large, DIM out of range) a
+ // message is emitted and an invalid intrinsic is returned so that the
+ // reference is not folded again.
+ template <typename T> Expr<T> Folder<T>::SPREAD(FunctionRef<T> &&funcRef) {
+   auto args{funcRef.arguments()};
+   CHECK(args.size() == 3);
+   const Constant<T> *source{UnwrapConstantValue<T>(args[0])};
+   auto dim{GetInt64Arg(args[1])};
+   auto ncopies{GetInt64Arg(args[2])};
+   if (!source || !dim) {
+     return Expr<T>{std::move(funcRef)};
+   }
+   int sourceRank{source->Rank()};
+   if (sourceRank >= common::maxRank) {
+     context_.messages().Say(
+         "SOURCE= argument to SPREAD has rank %d but must have rank less than %d"_err_en_US,
+         sourceRank, common::maxRank);
+   } else if (*dim < 1 || *dim > sourceRank + 1) {
+     // *dim is a 64-bit integer; format it with %jd (and a matching
+     // std::intmax_t argument) rather than %d so that the variadic
+     // formatter reads an argument of the type it expects.
+     context_.messages().Say(
+         "DIM=%jd argument to SPREAD must be between 1 and %d"_err_en_US,
+         static_cast<std::intmax_t>(*dim), sourceRank + 1);
+   } else if (!ncopies) {
+     return Expr<T>{std::move(funcRef)};
+   } else {
+     if (*ncopies < 0) {
+       ncopies = 0; // a negative NCOPIES produces a zero-extent dimension
+     }
+     // TODO: Consider moving this implementation (after the user error
+     // checks), along with other transformational intrinsics, into
+     // constant.h (or a new header) so that the transformationals
+     // are available for all Constant<>s without needing to be packaged
+     // as references to intrinsic functions for folding.
+     int copiesDim{static_cast<int>(*dim - 1)}; // zero-based; in [0, sourceRank]
+     ConstantSubscripts shape{source->shape()};
+     shape.insert(shape.begin() + copiesDim, *ncopies);
+     // Gives the result its shape; every element is then overwritten by
+     // CopyFrom() below, since it is asked to copy the full element count.
+     Constant<T> spread{source->Reshape(std::move(shape))};
+     // NOTE: relies on CopyFrom()/IncrementSubscripts() (constant.cpp)
+     // treating dimOrder[j] as the result dimension that varies j-th
+     // fastest, and on CopyFrom() restarting at SOURCE's first element once
+     // it is exhausted.  The result dimensions that came from SOURCE must
+     // therefore vary in their original order (skipping copiesDim), and the
+     // copies dimension must vary slowest, so that each pass over SOURCE
+     // fills exactly one copy.  (Listing copiesDim at position copiesDim
+     // instead yields the inverse permutation, which is only correct when
+     // that permutation happens to be its own inverse, e.g. rank-1 SOURCE.)
+     std::vector<int> dimOrder;
+     for (int j{0}; j < sourceRank; ++j) {
+       dimOrder.push_back(j < copiesDim ? j : j + 1);
+     }
+     dimOrder.push_back(copiesDim);
+     ConstantSubscripts at{spread.lbounds()}; // all 1
+     spread.CopyFrom(*source, TotalElementCount(spread.shape()), at, &dimOrder);
+     return Expr<T>{std::move(spread)};
+   }
+   // Invalid, prevent re-folding
+   return MakeInvalidIntrinsic(std::move(funcRef));
+ }
+
template <typename T> Expr<T> Folder<T>::TRANSPOSE(FunctionRef<T> &&funcRef) {
auto args{funcRef.arguments()};
CHECK(args.size() == 1);
@@ -1017,12 +1063,13 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
return Folder<T>{context}.PACK(std::move(funcRef));
} else if (name == "reshape") {
return Folder<T>{context}.RESHAPE(std::move(funcRef));
+ } else if (name == "spread") {
+ return Folder<T>{context}.SPREAD(std::move(funcRef));
} else if (name == "transpose") {
return Folder<T>{context}.TRANSPOSE(std::move(funcRef));
} else if (name == "unpack") {
return Folder<T>{context}.UNPACK(std::move(funcRef));
}
- // TODO: spread
// TODO: extends_type_of, same_type_as
if constexpr (!std::is_same_v<T, SomeDerived>) {
return FoldIntrinsicFunction(context, std::move(funcRef));
diff --git a/flang/test/Evaluate/folding19.f90 b/flang/test/Evaluate/errors01.f90
similarity index 88%
rename from flang/test/Evaluate/folding19.f90
rename to flang/test/Evaluate/errors01.f90
index 32d4be7e01db..2ba5444da931 100644
--- a/flang/test/Evaluate/folding19.f90
+++ b/flang/test/Evaluate/errors01.f90
@@ -90,4 +90,14 @@ subroutine s8
!CHECK: error: SHIFT=65 count for shiftl is greater than 64
integer(8), parameter :: bad6 = shiftl(1_8, 65)
end subroutine
+ subroutine s9
+   ! SPREAD folding diagnostics: a SOURCE= whose rank is already the maximum,
+   ! and DIM= values just outside 1 .. RANK(SOURCE)+1 for a rank-2 SOURCE.
+   integer, parameter :: src15(1,1,1,1,1,1,1,1,1,1,1,1,1,1,1) = 1
+   !CHECK: error: SOURCE= argument to SPREAD has rank 15 but must have rank less than 15
+   integer, parameter :: bad1 = spread(src15, dim=1, ncopies=1)
+   integer, parameter :: mat22(2, 2) = reshape([1, 2, 3, 4], shape=[2, 2])
+   !CHECK: error: DIM=0 argument to SPREAD must be between 1 and 3
+   integer, parameter :: bad2 = spread(mat22, dim=0, ncopies=1)
+   !CHECK: error: DIM=4 argument to SPREAD must be between 1 and 3
+   integer, parameter :: bad3 = spread(mat22, dim=4, ncopies=1)
+ end subroutine
end module
diff --git a/flang/test/Evaluate/fold-spread.f90 b/flang/test/Evaluate/fold-spread.f90
new file mode 100644
index 000000000000..127de8fbbe6a
--- /dev/null
+++ b/flang/test/Evaluate/fold-spread.f90
@@ -0,0 +1,19 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of SPREAD
+module m1
+  logical, parameter :: test_empty = size(spread(1, 1, 0)) == 0
+  ! A negative NCOPIES is treated as zero
+  logical, parameter :: test_negative = size(spread(1, 1, -1)) == 0
+  logical, parameter :: test_stov = all(spread(1, 1, 2) == [1, 1])
+  logical, parameter :: test_vtom1 = all(spread([1, 2], 1, 3) == reshape([1, 1, 1, 2, 2, 2], [3, 2]))
+  logical, parameter :: test_vtom2 = all(spread([1, 2], 2, 3) == reshape([1, 2, 1, 2, 1, 2], [2, 3]))
+  ! Rank-2 SOURCE: the copies form a new first, second, or third dimension
+  integer, parameter :: matrix(2, 2) = reshape([1, 2, 3, 4], [2, 2])
+  logical, parameter :: test_mtoa_shape = all(shape(spread(matrix, 1, 3)) == [3, 2, 2])
+  logical, parameter :: test_mtoa2 = all(spread(matrix, 2, 2) == reshape([1, 2, 1, 2, 3, 4, 3, 4], [2, 2, 2]))
+  logical, parameter :: test_mtoa3 = all(spread(matrix, 3, 2) == reshape([1, 2, 3, 4, 1, 2, 3, 4], [2, 2, 2]))
+  logical, parameter :: test_log1 = all(all(spread([.false., .true.], 1, 2), dim=2) .eqv. [.false., .false.])
+  logical, parameter :: test_log2 = all(all(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.])
+  logical, parameter :: test_log3 = all(any(spread([.false., .true.], 1, 2), dim=2) .eqv. [.true., .true.])
+  logical, parameter :: test_log4 = all(any(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.])
+end module
More information about the flang-commits
mailing list