[flang-commits] [flang] 8926f0f - [flang] Fold MERGE() of derived type values

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Tue Aug 8 12:10:28 PDT 2023


Author: Peter Klausler
Date: 2023-08-08T12:02:31-07:00
New Revision: 8926f0fe62a55fc0de7d6839700513328bc0e13f

URL: https://github.com/llvm/llvm-project/commit/8926f0fe62a55fc0de7d6839700513328bc0e13f
DIFF: https://github.com/llvm/llvm-project/commit/8926f0fe62a55fc0de7d6839700513328bc0e13f.diff

LOG: [flang] Fold MERGE() of derived type values

Generalize FoldMerge() to accommodate derived type arguments and results,
rename it to Folder<T>::MERGE(), dispatch it from FoldOperation() for all
types, and remove its calls from the FoldIntrinsicFunction() routines for intrinsic types.

Fixes llvm-test-suite/Fortran/gfortran/regression/merge_init_expr_2.f90.

Differential Revision: https://reviews.llvm.org/D157345

Added: 
    flang/test/Evaluate/fold-merge.f90

Modified: 
    flang/lib/Evaluate/fold-character.cpp
    flang/lib/Evaluate/fold-complex.cpp
    flang/lib/Evaluate/fold-implementation.h
    flang/lib/Evaluate/fold-integer.cpp
    flang/lib/Evaluate/fold-logical.cpp
    flang/lib/Evaluate/fold-real.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp
index 2a55334866aa8a..a599815fa7aee9 100644
--- a/flang/lib/Evaluate/fold-character.cpp
+++ b/flang/lib/Evaluate/fold-character.cpp
@@ -80,8 +80,6 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
       return FoldMaxvalMinval<T>(
           context, std::move(funcRef), RelationalOperator::GT, *identity);
     }
-  } else if (name == "merge") {
-    return FoldMerge<T>(context, std::move(funcRef));
   } else if (name == "min") {
     return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
   } else if (name == "minval") {

diff  --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp
index efdb18f8891322..520121ad254de7 100644
--- a/flang/lib/Evaluate/fold-complex.cpp
+++ b/flang/lib/Evaluate/fold-complex.cpp
@@ -64,8 +64,6 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
     }
   } else if (name == "dot_product") {
     return FoldDotProduct<T>(context, std::move(funcRef));
-  } else if (name == "merge") {
-    return FoldMerge<T>(context, std::move(funcRef));
   } else if (name == "product") {
     auto one{Scalar<Part>::FromInteger(value::Integer<8>{1}).value};
     return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{one});

diff  --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index aaa13ec371753c..c47a22c99a4577 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -64,6 +64,7 @@ template <typename T> class Folder {
 
   Expr<T> CSHIFT(FunctionRef<T> &&);
   Expr<T> EOSHIFT(FunctionRef<T> &&);
+  Expr<T> MERGE(FunctionRef<T> &&);
   Expr<T> PACK(FunctionRef<T> &&);
   Expr<T> RESHAPE(FunctionRef<T> &&);
   Expr<T> SPREAD(FunctionRef<T> &&);
@@ -397,9 +398,11 @@ template <typename T> Expr<T> Folder<T>::Folding(Designator<T> &&designator) {
 template <typename T>
 Constant<T> *Folder<T>::Folding(std::optional<ActualArgument> &arg) {
   if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
-    if (!UnwrapExpr<Expr<T>>(*expr)) {
-      if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
-        *expr = Fold(context_, std::move(*converted));
+    if constexpr (T::category != TypeCategory::Derived) {
+      if (!UnwrapExpr<Expr<T>>(*expr)) {
+        if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
+          *expr = Fold(context_, std::move(*converted));
+        }
       }
     }
     return UnwrapConstantValue<T>(*expr);
@@ -411,8 +414,6 @@ template <typename... A, std::size_t... I>
 std::optional<std::tuple<const Constant<A> *...>> GetConstantArgumentsHelper(
     FoldingContext &context, ActualArguments &arguments,
     std::index_sequence<I...>) {
-  static_assert(
-      (... && IsSpecificIntrinsicType<A>)); // TODO derived types for MERGE?
   static_assert(sizeof...(A) > 0);
   std::tuple<const Constant<A> *...> args{
       Folder<A>{context}.Folding(arguments.at(I))...};
@@ -489,7 +490,6 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
       }
     }
     CHECK(rank == GetRank(shape));
-
     // Compute all the scalar values of the results
     std::vector<Scalar<TR>> results;
     if (TotalElementCount(shape) > 0) {
@@ -513,6 +513,13 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
       auto len{static_cast<ConstantSubscript>(
           results.empty() ? 0 : results[0].length())};
       return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
+    } else if constexpr (TR::category == TypeCategory::Derived) {
+      if (!results.empty()) {
+        return Expr<TR>{rank == 0
+                ? Constant<TR>{results.front()}
+                : Constant<TR>{results.front().derivedTypeSpec(),
+                      std::move(results), std::move(shape)}};
+      }
     } else {
       return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
     }
@@ -780,6 +787,16 @@ template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
   return MakeInvalidIntrinsic(std::move(funcRef));
 }
 
+template <typename T> Expr<T> Folder<T>::MERGE(FunctionRef<T> &&funcRef) { // fold MERGE elementally; T may be derived
+  return FoldElementalIntrinsic<T, T, T, LogicalResult>(context_,
+      std::move(funcRef),
+      ScalarFunc<T, T, T, LogicalResult>(
+          [](const Scalar<T> &ifTrue, const Scalar<T> &ifFalse,
+              const Scalar<LogicalResult> &predicate) -> Scalar<T> {
+            return predicate.IsTrue() ? ifTrue : ifFalse; // per-element pick by mask
+          }));
+}
+
 template <typename T> Expr<T> Folder<T>::PACK(FunctionRef<T> &&funcRef) {
   auto args{funcRef.arguments()};
   CHECK(args.size() == 3);
@@ -1126,6 +1143,8 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
       return Folder<T>{context}.CSHIFT(std::move(funcRef));
     } else if (name == "eoshift") {
       return Folder<T>{context}.EOSHIFT(std::move(funcRef));
+    } else if (name == "merge") {
+      return Folder<T>{context}.MERGE(std::move(funcRef));
     } else if (name == "pack") {
       return Folder<T>{context}.PACK(std::move(funcRef));
     } else if (name == "reshape") {
@@ -1147,17 +1166,6 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
   return Expr<T>{std::move(funcRef)};
 }
 
-template <typename T>
-Expr<T> FoldMerge(FoldingContext &context, FunctionRef<T> &&funcRef) {
-  return FoldElementalIntrinsic<T, T, T, LogicalResult>(context,
-      std::move(funcRef),
-      ScalarFunc<T, T, T, LogicalResult>(
-          [](const Scalar<T> &ifTrue, const Scalar<T> &ifFalse,
-              const Scalar<LogicalResult> &predicate) -> Scalar<T> {
-            return predicate.IsTrue() ? ifTrue : ifFalse;
-          }));
-}
-
 Expr<ImpliedDoIndex::Result> FoldOperation(FoldingContext &, ImpliedDoIndex &&);
 
 // Array constructor folding

diff  --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 02d5ea5a133ada..53659d2c36d7c2 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -1038,8 +1038,6 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
   } else if (name == "maxval") {
     return FoldMaxvalMinval<T>(context, std::move(funcRef),
         RelationalOperator::GT, T::Scalar::Least());
-  } else if (name == "merge") {
-    return FoldMerge<T>(context, std::move(funcRef));
   } else if (name == "merge_bits") {
     return FoldElementalIntrinsic<T, T, T, T>(
         context, std::move(funcRef), &Scalar<T>::MERGE_BITS);

diff  --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
index 129a8fc40577d0..0803c868368119 100644
--- a/flang/lib/Evaluate/fold-logical.cpp
+++ b/flang/lib/Evaluate/fold-logical.cpp
@@ -215,8 +215,6 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
     if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
       return Fold(context, ConvertToType<T>(std::move(*expr)));
     }
-  } else if (name == "merge") {
-    return FoldMerge<T>(context, std::move(funcRef));
   } else if (name == "parity") {
     return FoldAllAnyParity(
         context, std::move(funcRef), &Scalar<T>::NEQV, Scalar<T>{false});

diff  --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 01a97951b0412a..671d897ef7b2f8 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -184,8 +184,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
   } else if (name == "maxval") {
     return FoldMaxvalMinval<T>(context, std::move(funcRef),
         RelationalOperator::GT, T::Scalar::HUGE().Negate());
-  } else if (name == "merge") {
-    return FoldMerge<T>(context, std::move(funcRef));
   } else if (name == "min") {
     return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
   } else if (name == "minval") {

diff  --git a/flang/test/Evaluate/fold-merge.f90 b/flang/test/Evaluate/fold-merge.f90
new file mode 100644
index 00000000000000..9cbd0ca7f2a99c
--- /dev/null
+++ b/flang/test/Evaluate/fold-merge.f90
@@ -0,0 +1,22 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of MERGE
+module m ! MERGE folding for intrinsic types and derived type t
+  type t
+    integer n
+  end type
+  logical, parameter :: test_01 = all(merge([1,2,3],4,[.true.,.false.,.true.]) == [1,4,3])
+  logical, parameter :: test_02 = all(merge([1,2,3],4,.true.) == [1,2,3])
+  logical, parameter :: test_03 = all(merge([1,2,3],4,.false.) == [4,4,4])
+  logical, parameter :: test_04 = all(merge(1,4,[.true.,.false.,.true.,.false.]) == [1,4,1,4]) ! scalar sources, array mask
+  type(t), parameter :: dt00a = merge(t(1),t(2),.true.) ! all-scalar derived-type case
+  logical, parameter :: test_05 = dt00a%n == 1
+  type(t), parameter :: dt00b = merge(t(1),t(2),.false.)
+  logical, parameter :: test_06 = dt00b%n == 2
+  type(t), parameter :: dt01(*) = merge([t(1),t(2)],[t(3),t(4)],[.false.,.true.])
+  logical, parameter :: test_07 = all(dt01%n == [3,2])
+  type(t), parameter :: dt02(*) = merge(t(1),[t(3),t(4)],.true.) ! scalar first source, array second
+  logical, parameter :: test_08 = all(dt02%n == [1,1])
+  type(t), parameter :: dt03(*) = merge([t(1),t(2)],t(3),[.true.,.false.]) ! scalar second source, arrays elsewhere
+  logical, parameter :: test_09 = all(dt03%n == [1,3])
+  logical, parameter :: test_10 = merge('ab','cd',.true.) == 'ab' ! character arguments
+end


        


More information about the flang-commits mailing list