[flang-commits] [flang] [flang] Don't inject possibly invalid conversions while folding (PR #100842)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Fri Jul 26 17:15:56 PDT 2024


https://github.com/klausler created https://github.com/llvm/llvm-project/pull/100842

A couple of intrinsic functions (e.g., ISHFTC and FINDLOC) have optional arguments.  While folding, don't insert type conversions on those arguments when the actual arguments may not be present at execution time because they are OPTIONAL dummy arguments, allocatables, or pointers.

From 5c286580a28ffb4f2e8656bdd92c05bad8f0316c Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Fri, 26 Jul 2024 17:13:27 -0700
Subject: [PATCH] [flang] Don't inject possibly invalid conversions while
 folding

A couple of intrinsic functions (e.g., ISHFTC and FINDLOC) have
optional arguments.  While folding, don't insert type conversions
on those arguments when the actual arguments may not be present at
execution time because they are OPTIONAL dummy arguments,
allocatables, or pointers.
---
 flang/lib/Evaluate/fold-character.cpp    |  6 +--
 flang/lib/Evaluate/fold-implementation.h | 52 ++++++++++++++++--------
 flang/lib/Evaluate/fold-integer.cpp      | 12 ++++--
 flang/lib/Evaluate/fold-real.cpp         |  4 +-
 flang/test/Evaluate/rewrite08.f90        | 21 ++++++++++
 5 files changed, 68 insertions(+), 27 deletions(-)
 create mode 100644 flang/test/Evaluate/rewrite08.f90

diff --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp
index 877bc2eac1fc2..5bdfa539eb0e0 100644
--- a/flang/lib/Evaluate/fold-character.cpp
+++ b/flang/lib/Evaluate/fold-character.cpp
@@ -97,7 +97,7 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
     return Expr<T>{Constant<T>{CharacterUtils<KIND>::NEW_LINE()}};
   } else if (name == "repeat") { // not elemental
     if (auto scalars{GetScalarConstantArguments<T, SubscriptInteger>(
-            context, funcRef.arguments())}) {
+            context, funcRef.arguments(), /*hasOptionalArgument=*/false)}) {
       auto str{std::get<Scalar<T>>(*scalars)};
       auto n{std::get<Scalar<SubscriptInteger>>(*scalars).ToInt64()};
       if (n < 0) {
@@ -117,8 +117,8 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
       }
     }
   } else if (name == "trim") { // not elemental
-    if (auto scalar{
-            GetScalarConstantArguments<T>(context, funcRef.arguments())}) {
+    if (auto scalar{GetScalarConstantArguments<T>(
+            context, funcRef.arguments(), /*hasOptionalArgument=*/false)}) {
       return Expr<T>{Constant<T>{
           CharacterUtils<KIND>::TRIM(std::get<Scalar<T>>(*scalar))}};
     }
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index d5c393140c574..9ce0edbdcb779 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -54,7 +54,8 @@ static constexpr bool useKahanSummation{false};
 // Utilities
 template <typename T> class Folder {
 public:
-  explicit Folder(FoldingContext &c) : context_{c} {}
+  explicit Folder(FoldingContext &c, bool forOptionalArgument = false)
+      : context_{c}, forOptionalArgument_{forOptionalArgument} {}
   std::optional<Constant<T>> GetNamedConstant(const Symbol &);
   std::optional<Constant<T>> ApplySubscripts(const Constant<T> &array,
       const std::vector<Constant<SubscriptInteger>> &subscripts);
@@ -81,6 +82,7 @@ template <typename T> class Folder {
 
 private:
   FoldingContext &context_;
+  bool forOptionalArgument_{false};
 };
 
 std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
@@ -407,7 +409,14 @@ Constant<T> *Folder<T>::Folding(std::optional<ActualArgument> &arg) {
   if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
     if constexpr (T::category != TypeCategory::Derived) {
       if (!UnwrapExpr<Expr<T>>(*expr)) {
-        if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
+        if (const Symbol *
+                var{forOptionalArgument_
+                        ? UnwrapWholeSymbolOrComponentDataRef(*expr)
+                        : nullptr};
+            var && (IsOptional(*var) || IsAllocatableOrObjectPointer(var))) {
+          // can't safely convert item that may not be present
+        } else if (auto converted{
+                       ConvertToType(T::GetType(), std::move(*expr))}) {
           *expr = Fold(context_, std::move(*converted));
         }
       }
@@ -420,10 +429,10 @@ Constant<T> *Folder<T>::Folding(std::optional<ActualArgument> &arg) {
 template <typename... A, std::size_t... I>
 std::optional<std::tuple<const Constant<A> *...>> GetConstantArgumentsHelper(
     FoldingContext &context, ActualArguments &arguments,
-    std::index_sequence<I...>) {
+    bool hasOptionalArgument, std::index_sequence<I...>) {
   static_assert(sizeof...(A) > 0);
   std::tuple<const Constant<A> *...> args{
-      Folder<A>{context}.Folding(arguments.at(I))...};
+      Folder<A>{context, hasOptionalArgument}.Folding(arguments.at(I))...};
   if ((... && (std::get<I>(args)))) {
     return args;
   } else {
@@ -433,15 +442,17 @@ std::optional<std::tuple<const Constant<A> *...>> GetConstantArgumentsHelper(
 
 template <typename... A>
 std::optional<std::tuple<const Constant<A> *...>> GetConstantArguments(
-    FoldingContext &context, ActualArguments &args) {
+    FoldingContext &context, ActualArguments &args, bool hasOptionalArgument) {
   return GetConstantArgumentsHelper<A...>(
-      context, args, std::index_sequence_for<A...>{});
+      context, args, hasOptionalArgument, std::index_sequence_for<A...>{});
 }
 
 template <typename... A, std::size_t... I>
 std::optional<std::tuple<Scalar<A>...>> GetScalarConstantArgumentsHelper(
-    FoldingContext &context, ActualArguments &args, std::index_sequence<I...>) {
-  if (auto constArgs{GetConstantArguments<A...>(context, args)}) {
+    FoldingContext &context, ActualArguments &args, bool hasOptionalArgument,
+    std::index_sequence<I...>) {
+  if (auto constArgs{
+          GetConstantArguments<A...>(context, args, hasOptionalArgument)}) {
     return std::tuple<Scalar<A>...>{
         std::get<I>(*constArgs)->GetScalarValue().value()...};
   } else {
@@ -451,9 +462,9 @@ std::optional<std::tuple<Scalar<A>...>> GetScalarConstantArgumentsHelper(
 
 template <typename... A>
 std::optional<std::tuple<Scalar<A>...>> GetScalarConstantArguments(
-    FoldingContext &context, ActualArguments &args) {
+    FoldingContext &context, ActualArguments &args, bool hasOptionalArgument) {
   return GetScalarConstantArgumentsHelper<A...>(
-      context, args, std::index_sequence_for<A...>{});
+      context, args, hasOptionalArgument, std::index_sequence_for<A...>{});
 }
 
 // helpers to fold intrinsic function references
@@ -470,9 +481,10 @@ template <template <typename, typename...> typename WrapperType, typename TR,
     typename... TA, std::size_t... I>
 Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
     FunctionRef<TR> &&funcRef, WrapperType<TR, TA...> func,
-    std::index_sequence<I...>) {
+    bool hasOptionalArgument, std::index_sequence<I...>) {
   if (std::optional<std::tuple<const Constant<TA> *...>> args{
-          GetConstantArguments<TA...>(context, funcRef.arguments())}) {
+          GetConstantArguments<TA...>(
+              context, funcRef.arguments(), hasOptionalArgument)}) {
     // Compute the shape of the result based on shapes of arguments
     ConstantSubscripts shape;
     int rank{0};
@@ -542,15 +554,19 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
 
 template <typename TR, typename... TA>
 Expr<TR> FoldElementalIntrinsic(FoldingContext &context,
-    FunctionRef<TR> &&funcRef, ScalarFunc<TR, TA...> func) {
-  return FoldElementalIntrinsicHelper<ScalarFunc, TR, TA...>(
-      context, std::move(funcRef), func, std::index_sequence_for<TA...>{});
+    FunctionRef<TR> &&funcRef, ScalarFunc<TR, TA...> func,
+    bool hasOptionalArgument = false) {
+  return FoldElementalIntrinsicHelper<ScalarFunc, TR, TA...>(context,
+      std::move(funcRef), func, hasOptionalArgument,
+      std::index_sequence_for<TA...>{});
 }
 template <typename TR, typename... TA>
 Expr<TR> FoldElementalIntrinsic(FoldingContext &context,
-    FunctionRef<TR> &&funcRef, ScalarFuncWithContext<TR, TA...> func) {
-  return FoldElementalIntrinsicHelper<ScalarFuncWithContext, TR, TA...>(
-      context, std::move(funcRef), func, std::index_sequence_for<TA...>{});
+    FunctionRef<TR> &&funcRef, ScalarFuncWithContext<TR, TA...> func,
+    bool hasOptionalArgument = false) {
+  return FoldElementalIntrinsicHelper<ScalarFuncWithContext, TR, TA...>(context,
+      std::move(funcRef), func, hasOptionalArgument,
+      std::index_sequence_for<TA...>{});
 }
 
 std::optional<std::int64_t> GetInt64ArgOr(
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 981cdff7f350b..39de171dd7a8b 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -347,7 +347,8 @@ template <WhichLocation WHICH> class LocationHelper {
     bool back{false};
     if (arg_[backArg]) {
       const auto *backConst{
-          Folder<LogicalResult>{context_}.Folding(arg_[backArg])};
+          Folder<LogicalResult>{context_, /*forOptionalArgument=*/true}.Folding(
+              arg_[backArg])};
       if (backConst) {
         back = backConst->GetScalarValue().value().IsTrue();
       } else {
@@ -910,8 +911,10 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
     const auto *argCon{Folder<T>(context).Folding(args[0])};
     const auto *shiftCon{Folder<Int4>(context).Folding(args[1])};
     const auto *shiftVals{shiftCon ? &shiftCon->values() : nullptr};
-    const auto *sizeCon{
-        args.size() == 3 ? Folder<Int4>(context).Folding(args[2]) : nullptr};
+    const auto *sizeCon{args.size() == 3
+            ? Folder<Int4>{context, /*forOptionalArgument=*/true}.Folding(
+                  args[2])
+            : nullptr};
     const auto *sizeVals{sizeCon ? &sizeCon->values() : nullptr};
     if ((argCon && argCon->empty()) || !shiftVals || shiftVals->empty() ||
         (sizeVals && sizeVals->empty())) {
@@ -985,7 +988,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
                 auto shiftVal{static_cast<int>(shift.ToInt64())};
                 auto sizeVal{static_cast<int>(size.ToInt64())};
                 return i.ISHFTC(shiftVal, sizeVal);
-              }));
+              }),
+          /*hasOptionalArgument=*/true);
     }
   } else if (name == "izext" || name == "jzext") {
     if (args.size() == 1) {
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 69c7a924cc1c3..2dd08a7d1a6e1 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -20,8 +20,8 @@ static Expr<T> FoldTransformationalBessel(
   /// arguments to Int4, any overflow error will be reported during the
   /// conversion folding.
   using Int4 = Type<TypeCategory::Integer, 4>;
-  if (auto args{
-          GetConstantArguments<Int4, Int4, T>(context, funcRef.arguments())}) {
+  if (auto args{GetConstantArguments<Int4, Int4, T>(
+          context, funcRef.arguments(), /*hasOptionalArgument=*/false)}) {
     const std::string &name{std::get<SpecificIntrinsic>(funcRef.proc().u).name};
     if (auto elementalBessel{GetHostRuntimeWrapper<T, Int4, T>(name)}) {
       std::vector<Scalar<T>> results;
diff --git a/flang/test/Evaluate/rewrite08.f90 b/flang/test/Evaluate/rewrite08.f90
new file mode 100644
index 0000000000000..c59605581d63c
--- /dev/null
+++ b/flang/test/Evaluate/rewrite08.f90
@@ -0,0 +1,21 @@
+! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
+subroutine s(oi,ol)
+  integer(1), optional, intent(in) :: oi
+  logical(1), optional, intent(in) :: ol
+  integer(1), allocatable :: ai
+  logical(1), allocatable :: al
+  integer(1), pointer :: pi
+  logical(1), pointer :: pl
+!CHECK: PRINT *, ishftc(-1_4,1_4,oi)
+!CHECK: PRINT *, ishftc(-1_4,1_4,ai)
+!CHECK: PRINT *, ishftc(-1_4,1_4,pi)
+!CHECK: PRINT *, findloc([INTEGER(4)::1_4,2_4,1_4],1_4,back=ol)
+!CHECK: PRINT *, findloc([INTEGER(4)::1_4,2_4,1_4],1_4,back=al)
+!CHECK: PRINT *, findloc([INTEGER(4)::1_4,2_4,1_4],1_4,back=pl)
+  print *, ishftc(-1,1,oi)
+  print *, ishftc(-1,1,ai)
+  print *, ishftc(-1,1,pi)
+  print *, findloc([1,2,1],1,back=ol)
+  print *, findloc([1,2,1],1,back=al)
+  print *, findloc([1,2,1],1,back=pl)
+end



More information about the flang-commits mailing list