[flang-commits] [flang] 8926f0f - [flang] Fold MERGE() of derived type values
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Tue Aug 8 12:10:28 PDT 2023
Author: Peter Klausler
Date: 2023-08-08T12:02:31-07:00
New Revision: 8926f0fe62a55fc0de7d6839700513328bc0e13f
URL: https://github.com/llvm/llvm-project/commit/8926f0fe62a55fc0de7d6839700513328bc0e13f
DIFF: https://github.com/llvm/llvm-project/commit/8926f0fe62a55fc0de7d6839700513328bc0e13f.diff
LOG: [flang] Fold MERGE() of derived type values
Generalize FoldMerge() to accommodate derived-type arguments and results,
rename it to Folder<T>::MERGE(), and remove the per-type "merge" cases from
the various FoldIntrinsicFunction() routines for intrinsic types; MERGE is
now dispatched for every type from the generic FoldOperation().
Fixes llvm-test-suite/Fortran/gfortran/regression/merge_init_expr_2.f90.
Differential Revision: https://reviews.llvm.org/D157345
Added:
flang/test/Evaluate/fold-merge.f90
Modified:
flang/lib/Evaluate/fold-character.cpp
flang/lib/Evaluate/fold-complex.cpp
flang/lib/Evaluate/fold-implementation.h
flang/lib/Evaluate/fold-integer.cpp
flang/lib/Evaluate/fold-logical.cpp
flang/lib/Evaluate/fold-real.cpp
Removed:
################################################################################
diff --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp
index 2a55334866aa8a..a599815fa7aee9 100644
--- a/flang/lib/Evaluate/fold-character.cpp
+++ b/flang/lib/Evaluate/fold-character.cpp
@@ -80,8 +80,6 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
return FoldMaxvalMinval<T>(
context, std::move(funcRef), RelationalOperator::GT, *identity);
}
- } else if (name == "merge") {
- return FoldMerge<T>(context, std::move(funcRef));
} else if (name == "min") {
return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
} else if (name == "minval") {
diff --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp
index efdb18f8891322..520121ad254de7 100644
--- a/flang/lib/Evaluate/fold-complex.cpp
+++ b/flang/lib/Evaluate/fold-complex.cpp
@@ -64,8 +64,6 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
}
} else if (name == "dot_product") {
return FoldDotProduct<T>(context, std::move(funcRef));
- } else if (name == "merge") {
- return FoldMerge<T>(context, std::move(funcRef));
} else if (name == "product") {
auto one{Scalar<Part>::FromInteger(value::Integer<8>{1}).value};
return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{one});
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index aaa13ec371753c..c47a22c99a4577 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -64,6 +64,7 @@ template <typename T> class Folder {
Expr<T> CSHIFT(FunctionRef<T> &&);
Expr<T> EOSHIFT(FunctionRef<T> &&);
+ Expr<T> MERGE(FunctionRef<T> &&);
Expr<T> PACK(FunctionRef<T> &&);
Expr<T> RESHAPE(FunctionRef<T> &&);
Expr<T> SPREAD(FunctionRef<T> &&);
@@ -397,9 +398,11 @@ template <typename T> Expr<T> Folder<T>::Folding(Designator<T> &&designator) {
template <typename T>
Constant<T> *Folder<T>::Folding(std::optional<ActualArgument> &arg) {
if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
- if (!UnwrapExpr<Expr<T>>(*expr)) {
- if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
- *expr = Fold(context_, std::move(*converted));
+ if constexpr (T::category != TypeCategory::Derived) {
+ if (!UnwrapExpr<Expr<T>>(*expr)) {
+ if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
+ *expr = Fold(context_, std::move(*converted));
+ }
}
}
return UnwrapConstantValue<T>(*expr);
@@ -411,8 +414,6 @@ template <typename... A, std::size_t... I>
std::optional<std::tuple<const Constant<A> *...>> GetConstantArgumentsHelper(
FoldingContext &context, ActualArguments &arguments,
std::index_sequence<I...>) {
- static_assert(
- (... && IsSpecificIntrinsicType<A>)); // TODO derived types for MERGE?
static_assert(sizeof...(A) > 0);
std::tuple<const Constant<A> *...> args{
Folder<A>{context}.Folding(arguments.at(I))...};
@@ -489,7 +490,6 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
}
}
CHECK(rank == GetRank(shape));
-
// Compute all the scalar values of the results
std::vector<Scalar<TR>> results;
if (TotalElementCount(shape) > 0) {
@@ -513,6 +513,13 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
auto len{static_cast<ConstantSubscript>(
results.empty() ? 0 : results[0].length())};
return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
+ } else if constexpr (TR::category == TypeCategory::Derived) {
+ if (!results.empty()) {
+ return Expr<TR>{rank == 0
+ ? Constant<TR>{results.front()}
+ : Constant<TR>{results.front().derivedTypeSpec(),
+ std::move(results), std::move(shape)}};
+ }
} else {
return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
}
@@ -780,6 +787,16 @@ template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
return MakeInvalidIntrinsic(std::move(funcRef));
}
+template <typename T> Expr<T> Folder<T>::MERGE(FunctionRef<T> &&funcRef) { // Folds MERGE for any T, intrinsic or derived; dispatched from FoldOperation() when name == "merge"
+ return FoldElementalIntrinsic<T, T, T, LogicalResult>(context_, // result type T; argument types T, T, LogicalResult
+ std::move(funcRef),
+ ScalarFunc<T, T, T, LogicalResult>(
+ [](const Scalar<T> &ifTrue, const Scalar<T> &ifFalse,
+ const Scalar<LogicalResult> &predicate) -> Scalar<T> { // invoked per result element by the elemental folding helper
+ return predicate.IsTrue() ? ifTrue : ifFalse; // first argument where the mask element is true, otherwise the second
+ }));
+} // NOTE(review): FoldElementalIntrinsicHelper's Derived branch builds a Constant only when results is non-empty (it needs results.front() for the type spec), so a zero-size derived-type MERGE presumably stays unfolded -- confirm
+
template <typename T> Expr<T> Folder<T>::PACK(FunctionRef<T> &&funcRef) {
auto args{funcRef.arguments()};
CHECK(args.size() == 3);
@@ -1126,6 +1143,8 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
return Folder<T>{context}.CSHIFT(std::move(funcRef));
} else if (name == "eoshift") {
return Folder<T>{context}.EOSHIFT(std::move(funcRef));
+ } else if (name == "merge") {
+ return Folder<T>{context}.MERGE(std::move(funcRef));
} else if (name == "pack") {
return Folder<T>{context}.PACK(std::move(funcRef));
} else if (name == "reshape") {
@@ -1147,17 +1166,6 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
return Expr<T>{std::move(funcRef)};
}
-template <typename T>
-Expr<T> FoldMerge(FoldingContext &context, FunctionRef<T> &&funcRef) {
- return FoldElementalIntrinsic<T, T, T, LogicalResult>(context,
- std::move(funcRef),
- ScalarFunc<T, T, T, LogicalResult>(
- [](const Scalar<T> &ifTrue, const Scalar<T> &ifFalse,
- const Scalar<LogicalResult> &predicate) -> Scalar<T> {
- return predicate.IsTrue() ? ifTrue : ifFalse;
- }));
-}
-
Expr<ImpliedDoIndex::Result> FoldOperation(FoldingContext &, ImpliedDoIndex &&);
// Array constructor folding
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 02d5ea5a133ada..53659d2c36d7c2 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -1038,8 +1038,6 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else if (name == "maxval") {
return FoldMaxvalMinval<T>(context, std::move(funcRef),
RelationalOperator::GT, T::Scalar::Least());
- } else if (name == "merge") {
- return FoldMerge<T>(context, std::move(funcRef));
} else if (name == "merge_bits") {
return FoldElementalIntrinsic<T, T, T, T>(
context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index 129a8fc40577d0..0803c868368119 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -215,8 +215,6 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
return Fold(context, ConvertToType<T>(std::move(*expr)));
}
- } else if (name == "merge") {
- return FoldMerge<T>(context, std::move(funcRef));
} else if (name == "parity") {
return FoldAllAnyParity(
context, std::move(funcRef), &Scalar<T>::NEQV, Scalar<T>{false});
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 01a97951b0412a..671d897ef7b2f8 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -184,8 +184,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
} else if (name == "maxval") {
return FoldMaxvalMinval<T>(context, std::move(funcRef),
RelationalOperator::GT, T::Scalar::HUGE().Negate());
- } else if (name == "merge") {
- return FoldMerge<T>(context, std::move(funcRef));
} else if (name == "min") {
return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
} else if (name == "minval") {
diff --git a/flang/test/Evaluate/fold-merge.f90 b/flang/test/Evaluate/fold-merge.f90
new file mode 100644
index 00000000000000..9cbd0ca7f2a99c
--- /dev/null
+++ b/flang/test/Evaluate/fold-merge.f90
@@ -0,0 +1,27 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of MERGE
+module m
+ type t
+ integer n
+ end type
+ logical, parameter :: test_01 = all(merge([1,2,3],4,[.true.,.false.,.true.]) == [1,4,3])
+ logical, parameter :: test_02 = all(merge([1,2,3],4,.true.) == [1,2,3])
+ logical, parameter :: test_03 = all(merge([1,2,3],4,.false.) == [4,4,4])
+ logical, parameter :: test_04 = all(merge(1,4,[.true.,.false.,.true.,.false.]) == [1,4,1,4])
+ type(t), parameter :: dt00a = merge(t(1),t(2),.true.)
+ logical, parameter :: test_05 = dt00a%n == 1
+ type(t), parameter :: dt00b = merge(t(1),t(2),.false.)
+ logical, parameter :: test_06 = dt00b%n == 2
+ type(t), parameter :: dt01(*) = merge([t(1),t(2)],[t(3),t(4)],[.false.,.true.])
+ logical, parameter :: test_07 = all(dt01%n == [3,2])
+ type(t), parameter :: dt02(*) = merge(t(1),[t(3),t(4)],.true.)
+ logical, parameter :: test_08 = all(dt02%n == [1,1])
+ type(t), parameter :: dt03(*) = merge([t(1),t(2)],t(3),[.true.,.false.])
+ logical, parameter :: test_09 = all(dt03%n == [1,3])
+ logical, parameter :: test_10 = merge('ab','cd',.true.) == 'ab'
+ ! Remaining intrinsic categories, now folded through the generic Folder<T>::MERGE()
+ logical, parameter :: test_11 = merge(1.5,2.5,.false.) == 2.5
+ logical, parameter :: test_12 = merge((1.,2.),(3.,4.),.true.) == (1.,2.)
+ logical, parameter :: test_13 = .not. merge(.true.,.false.,.false.)
+ logical, parameter :: test_14 = all(merge(['ab','cd'],['ef','gh'],[.false.,.true.]) == ['ef','cd'])
+end
More information about the flang-commits mailing list