[flang-commits] [flang] 488b9fd - [flang] Do not ICE on recursive function definition in function result

Daniil Dudkin via flang-commits flang-commits at lists.llvm.org
Thu Apr 21 03:16:08 PDT 2022


Author: Daniil Dudkin
Date: 2022-04-21T19:04:17+09:00
New Revision: 488b9fd1030b1e75a6c3580d0a632009315e31f5

URL: https://github.com/llvm/llvm-project/commit/488b9fd1030b1e75a6c3580d0a632009315e31f5
DIFF: https://github.com/llvm/llvm-project/commit/488b9fd1030b1e75a6c3580d0a632009315e31f5.diff

LOG: [flang] Do not ICE on recursive function definition in function result

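The following code causes the compiler to crash with an internal compiler
error (ICE) in several places, due to the lack of support for recursive
procedure definitions through the function result.

  function foo() result(r)
    procedure(foo), pointer :: r
  end function foo

This patch caps the recursion in Symbol::Rank() and Symbol::GetType() at a
fixed depth, and threads the set of already-visited procedures through the
characterization of function results, so that such a cycle is diagnosed as a
recursively defined procedure instead of crashing the compiler.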

Added: 
    

Modified: 
    flang/include/flang/Semantics/symbol.h
    flang/lib/Evaluate/characteristics.cpp
    flang/lib/Semantics/check-declarations.cpp
    flang/test/Semantics/resolve102.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index f836da4b3da0e..d10c1fe9c12e6 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -15,6 +15,7 @@
 #include "flang/Common/reference.h"
 #include "flang/Common/visit.h"
 #include "llvm/ADT/DenseMapInfo.h"
+
 #include <array>
 #include <functional>
 #include <list>
@@ -637,38 +638,7 @@ class Symbol {
   bool operator==(const Symbol &that) const { return this == &that; }
   bool operator!=(const Symbol &that) const { return !(*this == that); }
 
-  int Rank() const {
-    return common::visit(
-        common::visitors{
-            [](const SubprogramDetails &sd) {
-              return sd.isFunction() ? sd.result().Rank() : 0;
-            },
-            [](const GenericDetails &) {
-              return 0; /*TODO*/
-            },
-            [](const ProcBindingDetails &x) { return x.symbol().Rank(); },
-            [](const UseDetails &x) { return x.symbol().Rank(); },
-            [](const HostAssocDetails &x) { return x.symbol().Rank(); },
-            [](const ObjectEntityDetails &oed) { return oed.shape().Rank(); },
-            [](const ProcEntityDetails &ped) {
-              const Symbol *iface{ped.interface().symbol()};
-              return iface ? iface->Rank() : 0;
-            },
-            [](const AssocEntityDetails &aed) {
-              if (const auto &expr{aed.expr()}) {
-                if (auto assocRank{aed.rank()}) {
-                  return *assocRank;
-                } else {
-                  return expr->Rank();
-                }
-              } else {
-                return 0;
-              }
-            },
-            [](const auto &) { return 0; },
-        },
-        details_);
-  }
+  int Rank() const { return RankImpl(); }
 
   int Corank() const {
     return common::visit(
@@ -718,6 +688,48 @@ class Symbol {
   friend llvm::raw_ostream &DumpForUnparse(
       llvm::raw_ostream &, const Symbol &, bool);
 
+  static constexpr int startRecursionDepth{100};
+
+  inline const DeclTypeSpec *GetTypeImpl(int depth = startRecursionDepth) const;
+  inline int RankImpl(int depth = startRecursionDepth) const {
+    if (depth-- == 0) {
+      return 0;
+    }
+    return common::visit(
+        common::visitors{
+            [&](const SubprogramDetails &sd) {
+              return sd.isFunction() ? sd.result().RankImpl(depth) : 0;
+            },
+            [](const GenericDetails &) {
+              return 0; /*TODO*/
+            },
+            [&](const ProcBindingDetails &x) {
+              return x.symbol().RankImpl(depth);
+            },
+            [&](const UseDetails &x) { return x.symbol().RankImpl(depth); },
+            [&](const HostAssocDetails &x) {
+              return x.symbol().RankImpl(depth);
+            },
+            [](const ObjectEntityDetails &oed) { return oed.shape().Rank(); },
+            [&](const ProcEntityDetails &ped) {
+              const Symbol *iface{ped.interface().symbol()};
+              return iface ? iface->RankImpl(depth) : 0;
+            },
+            [](const AssocEntityDetails &aed) {
+              if (const auto &expr{aed.expr()}) {
+                if (auto assocRank{aed.rank()}) {
+                  return *assocRank;
+                } else {
+                  return expr->Rank();
+                }
+              } else {
+                return 0;
+              }
+            },
+            [](const auto &) { return 0; },
+        },
+        details_);
+  }
   template <std::size_t> friend class Symbols;
   template <class, std::size_t> friend class std::array;
 };
@@ -786,28 +798,38 @@ inline DeclTypeSpec *Symbol::GetType() {
   return const_cast<DeclTypeSpec *>(
       const_cast<const Symbol *>(this)->GetType());
 }
-inline const DeclTypeSpec *Symbol::GetType() const {
+
+inline const DeclTypeSpec *Symbol::GetTypeImpl(int depth) const {
+  if (depth-- == 0) {
+    return nullptr;
+  }
   return common::visit(
       common::visitors{
           [](const EntityDetails &x) { return x.type(); },
           [](const ObjectEntityDetails &x) { return x.type(); },
           [](const AssocEntityDetails &x) { return x.type(); },
-          [](const SubprogramDetails &x) {
-            return x.isFunction() ? x.result().GetType() : nullptr;
+          [&](const SubprogramDetails &x) {
+            return x.isFunction() ? x.result().GetTypeImpl(depth) : nullptr;
           },
-          [](const ProcEntityDetails &x) {
+          [&](const ProcEntityDetails &x) {
             const Symbol *symbol{x.interface().symbol()};
-            return symbol ? symbol->GetType() : x.interface().type();
+            return symbol ? symbol->GetTypeImpl(depth) : x.interface().type();
+          },
+          [&](const ProcBindingDetails &x) {
+            return x.symbol().GetTypeImpl(depth);
           },
-          [](const ProcBindingDetails &x) { return x.symbol().GetType(); },
           [](const TypeParamDetails &x) { return x.type(); },
-          [](const UseDetails &x) { return x.symbol().GetType(); },
-          [](const HostAssocDetails &x) { return x.symbol().GetType(); },
+          [&](const UseDetails &x) { return x.symbol().GetTypeImpl(depth); },
+          [&](const HostAssocDetails &x) {
+            return x.symbol().GetTypeImpl(depth);
+          },
           [](const auto &) -> const DeclTypeSpec * { return nullptr; },
       },
       details_);
 }
 
+inline const DeclTypeSpec *Symbol::GetType() const { return GetTypeImpl(); }
+
 // Sets and maps keyed by Symbols
 
 struct SymbolAddressCompare {

diff  --git a/flang/lib/Evaluate/characteristics.cpp b/flang/lib/Evaluate/characteristics.cpp
index 474c6080a0861..cac34df644b43 100644
--- a/flang/lib/Evaluate/characteristics.cpp
+++ b/flang/lib/Evaluate/characteristics.cpp
@@ -72,7 +72,7 @@ std::optional<TypeAndShape> TypeAndShape::Characterize(
   return common::visit(
       common::visitors{
           [&](const semantics::ProcEntityDetails &proc) {
-            const semantics::ProcInterface &interface{proc.interface()};
+            const semantics::ProcInterface &interface { proc.interface() };
             if (interface.type()) {
               return Characterize(*interface.type(), context);
             } else if (interface.symbol()) {
@@ -367,6 +367,9 @@ static std::string GetSeenProcs(
 static std::optional<DummyArgument> CharacterizeDummyArgument(
     const semantics::Symbol &symbol, FoldingContext &context,
     semantics::UnorderedSymbolSet seenProcs);
+static std::optional<FunctionResult> CharacterizeFunctionResult(
+    const semantics::Symbol &symbol, FoldingContext &context,
+    semantics::UnorderedSymbolSet seenProcs);
 
 static std::optional<Procedure> CharacterizeProcedure(
     const semantics::Symbol &original, FoldingContext &context,
@@ -397,8 +400,8 @@ static std::optional<Procedure> CharacterizeProcedure(
           [&](const semantics::SubprogramDetails &subp)
               -> std::optional<Procedure> {
             if (subp.isFunction()) {
-              if (auto fr{
-                      FunctionResult::Characterize(subp.result(), context)}) {
+              if (auto fr{CharacterizeFunctionResult(
+                      subp.result(), context, seenProcs)}) {
                 result.functionResult = std::move(fr);
               } else {
                 return std::nullopt;
@@ -438,7 +441,7 @@ static std::optional<Procedure> CharacterizeProcedure(
               }
               return intrinsic;
             }
-            const semantics::ProcInterface &interface{proc.interface()};
+            const semantics::ProcInterface &interface { proc.interface() };
             if (const semantics::Symbol * interfaceSymbol{interface.symbol()}) {
               return CharacterizeProcedure(
                   *interfaceSymbol, context, seenProcs);
@@ -699,8 +702,9 @@ bool FunctionResult::operator==(const FunctionResult &that) const {
   return attrs == that.attrs && u == that.u;
 }
 
-std::optional<FunctionResult> FunctionResult::Characterize(
-    const Symbol &symbol, FoldingContext &context) {
+static std::optional<FunctionResult> CharacterizeFunctionResult(
+    const semantics::Symbol &symbol, FoldingContext &context,
+    semantics::UnorderedSymbolSet seenProcs) {
   if (symbol.has<semantics::ObjectEntityDetails>()) {
     if (auto type{TypeAndShape::Characterize(symbol, context)}) {
       FunctionResult result{std::move(*type)};
@@ -712,7 +716,8 @@ std::optional<FunctionResult> FunctionResult::Characterize(
           });
       return result;
     }
-  } else if (auto maybeProc{Procedure::Characterize(symbol, context)}) {
+  } else if (auto maybeProc{
+                 CharacterizeProcedure(symbol, context, seenProcs)}) {
     FunctionResult result{std::move(*maybeProc)};
     result.attrs.set(FunctionResult::Attr::Pointer);
     return result;
@@ -720,6 +725,12 @@ std::optional<FunctionResult> FunctionResult::Characterize(
   return std::nullopt;
 }
 
+std::optional<FunctionResult> FunctionResult::Characterize(
+    const Symbol &symbol, FoldingContext &context) {
+  semantics::UnorderedSymbolSet seenProcs;
+  return CharacterizeFunctionResult(symbol, context, seenProcs);
+}
+
 bool FunctionResult::IsAssumedLengthCharacter() const {
   if (const auto *ts{std::get_if<TypeAndShape>(&u)}) {
     return ts->type().IsAssumedLengthCharacter();

diff  --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index ad999ed91a199..7ee267bd2a876 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -1786,6 +1786,9 @@ void CheckHelper::CheckGenericOps(const Scope &scope) {
   auto addSpecifics{[&](const Symbol &generic) {
     const auto *details{generic.GetUltimate().detailsIf<GenericDetails>()};
     if (!details) {
+      if (generic.test(Symbol::Flag::Function)) {
+        Characterize(generic);
+      }
       return;
     }
     GenericKind kind{details->kind()};

diff  --git a/flang/test/Semantics/resolve102.f90 b/flang/test/Semantics/resolve102.f90
index 07f12d7b76839..ef6a477032a1d 100644
--- a/flang/test/Semantics/resolve102.f90
+++ b/flang/test/Semantics/resolve102.f90
@@ -20,6 +20,12 @@ subroutine sub(p2)
     end subroutine
 end subroutine circular
 
+!ERROR: Procedure 'foo' is recursively defined.  Procedures in the cycle: 'foo', 'r'
+function foo() result(r)
+  !ERROR: Procedure 'r' is recursively defined.  Procedures in the cycle: 'foo', 'r'
+  procedure(foo), pointer :: r 
+end function foo
+
 program iface
   !ERROR: Procedure 'p' is recursively defined.  Procedures in the cycle: 'p', 'sub', 'p2'
   procedure(sub) :: p


        


More information about the flang-commits mailing list