[flang-commits] [flang] [flang] Defer processing of non-pointer variable initializers (PR #76475)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Mon Jan 1 12:01:49 PST 2024


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/76475

>From 087bd8dbdb806ba21c8b6230eb3b06974c7eae50 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 27 Dec 2023 14:45:35 -0800
Subject: [PATCH] [flang] Defer processing of non-pointer variable initializers

Initializers in entity-decls don't need to have their expressions
analyzed immediately during name resolution unless they define the
values of named constants.  By deferring the expression analysis,
the compiler can better handle references to module and internal
procedures that might appear in structure constructors.  At present,
these are typically rejected as forward references (which they may
indeed be) to subprogram names that can't yet be checked for
compatibility with the characteristics of the corresponding
procedure component.
---
 flang/lib/Semantics/resolve-names.cpp     | 105 +++++++++++++++-------
 flang/test/Semantics/bad-forward-type.f90 |   1 -
 flang/test/Semantics/init01.f90           |   2 +-
 flang/test/Semantics/pointer01.f90        |   1 +
 flang/test/Semantics/symbol15.f90         |  24 ++---
 5 files changed, 87 insertions(+), 46 deletions(-)

diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index f5f7b99aba2551..e30eb5070d789a 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -698,7 +698,7 @@ class ScopeHandler : public ImplicitRulesVisitor {
   bool CheckPossibleBadForwardRef(const Symbol &);
 
   bool inSpecificationPart_{false};
-  bool inDataStmtObject_{false};
+  bool deferImplicitTyping_{false};
   bool inEquivalenceStmt_{false};
 
   // Some information is collected from a specification part for deferred
@@ -1629,6 +1629,7 @@ class ResolveNamesVisitor : public virtual ScopeHandler,
   bool BeginScopeForNode(const ProgramTree &);
   void EndScopeForNode(const ProgramTree &);
   void FinishSpecificationParts(const ProgramTree &);
+  void FinishExecutionParts(const ProgramTree &);
   void FinishDerivedTypeInstantiation(Scope &);
   void ResolveExecutionParts(const ProgramTree &);
   void UseCUDABuiltinNames();
@@ -2533,7 +2534,7 @@ void ScopeHandler::ApplyImplicitRules(
     // or object, it'll be caught later.
     return;
   }
-  if (inDataStmtObject_) {
+  if (deferImplicitTyping_) {
     return;
   }
   if (!context().HasError(symbol)) {
@@ -2709,7 +2710,7 @@ const DeclTypeSpec &ScopeHandler::MakeLogicalType(int kind) {
 }
 
 void ScopeHandler::NotePossibleBadForwardRef(const parser::Name &name) {
-  if (inSpecificationPart_ && !inDataStmtObject_ && name.symbol) {
+  if (inSpecificationPart_ && !deferImplicitTyping_ && name.symbol) {
     auto kind{currScope().kind()};
     if ((kind == Scope::Kind::Subprogram && !currScope().IsStmtFunction()) ||
         kind == Scope::Kind::BlockConstruct) {
@@ -6802,7 +6803,8 @@ bool ConstructVisitor::Pre(const parser::DataStmtObject &x) {
   auto flagRestorer{common::ScopedSet(inSpecificationPart_, false)};
   common::visit(common::visitors{
                     [&](const Indirection<parser::Variable> &y) {
-                      auto restorer{common::ScopedSet(inDataStmtObject_, true)};
+                      auto restorer{
+                          common::ScopedSet(deferImplicitTyping_, true)};
                       Walk(y.value());
                       const parser::Name &first{
                           parser::GetFirstName(y.value())};
@@ -7386,7 +7388,7 @@ const parser::Name *DeclarationVisitor::ResolveName(const parser::Name &name) {
     }
     return &name;
   }
-  if (isImplicitNoneType() && !inDataStmtObject_) {
+  if (isImplicitNoneType() && !deferImplicitTyping_) {
     Say(name, "No explicit type declared for '%s'"_err_en_US);
     return nullptr;
   }
@@ -7548,7 +7550,15 @@ void DeclarationVisitor::Initialization(const parser::Name &name,
   common::visit(
       common::visitors{
           [&](const parser::ConstantExpr &expr) {
-            NonPointerInitialization(name, expr);
+            Walk(expr);
+            if (IsNamedConstant(ultimate) || inComponentDecl) {
+              NonPointerInitialization(name, expr);
+            } else {
+              // Defer analysis so forward references to nested subprograms
+              // can be properly resolved when they appear in structure
+              // constructors.
+              ultimate.set(Symbol::Flag::InDataStmt);
+            }
           },
           [&](const parser::NullInit &null) { // => NULL()
             Walk(null);
@@ -7569,10 +7579,12 @@ void DeclarationVisitor::Initialization(const parser::Name &name,
               }
             }
           },
-          [&](const parser::InitialDataTarget &) {
+          [&](const parser::InitialDataTarget &target) {
             // Defer analysis to the end of the specification part
             // so that forward references and attribute checks like SAVE
             // work better.
+            auto restorer{common::ScopedSet(deferImplicitTyping_, true)};
+            Walk(target);
             ultimate.set(Symbol::Flag::InDataStmt);
           },
           [&](const std::list<Indirection<parser::DataStmtValue>> &values) {
@@ -7590,12 +7602,27 @@ void DeclarationVisitor::PointerInitialization(
     Symbol &ultimate{name.symbol->GetUltimate()};
     if (!context().HasError(ultimate)) {
       if (IsPointer(ultimate)) {
-        if (auto *details{ultimate.detailsIf<ObjectEntityDetails>()}) {
-          CHECK(!details->init());
-          Walk(target);
-          if (MaybeExpr expr{EvaluateExpr(target)}) {
-            // Validation is done in declaration checking.
+        Walk(target);
+        if (MaybeExpr expr{EvaluateExpr(target)}) {
+          // Validation is done in declaration checking.
+          if (auto *details{ultimate.detailsIf<ObjectEntityDetails>()}) {
+            CHECK(!details->init());
             details->set_init(std::move(*expr));
+            ultimate.set(Symbol::Flag::InDataStmt, false);
+          } else if (auto *details{ultimate.detailsIf<ProcEntityDetails>()}) {
+            // something like "REAL, EXTERNAL, POINTER :: p => t"
+            if (evaluate::IsNullProcedurePointer(*expr)) {
+              CHECK(!details->init());
+              details->set_init(nullptr);
+            } else if (const Symbol *
+                targetSymbol{evaluate::UnwrapWholeSymbolDataRef(*expr)}) {
+              CHECK(!details->init());
+              details->set_init(*targetSymbol);
+            } else {
+              Say(name,
+                  "Procedure pointer '%s' must be initialized with a procedure name or NULL()"_err_en_US);
+              context().SetError(ultimate);
+            }
           }
         }
       } else {
@@ -7635,27 +7662,23 @@ void DeclarationVisitor::PointerInitialization(
 
 void DeclarationVisitor::NonPointerInitialization(
     const parser::Name &name, const parser::ConstantExpr &expr) {
-  if (name.symbol) {
+  if (!context().HasError(name.symbol)) {
     Symbol &ultimate{name.symbol->GetUltimate()};
-    if (!context().HasError(ultimate) && !context().HasError(name.symbol)) {
+    if (!context().HasError(ultimate)) {
       if (IsPointer(ultimate)) {
         Say(name,
             "'%s' is a pointer but is not initialized like one"_err_en_US);
       } else if (auto *details{ultimate.detailsIf<ObjectEntityDetails>()}) {
-        CHECK(!details->init());
-        if (IsAllocatable(ultimate)) {
+        if (details->init()) {
+        } else if (IsAllocatable(ultimate)) {
           Say(name, "Allocatable object '%s' cannot be initialized"_err_en_US);
-          return;
-        }
-        Walk(expr);
-        if (ultimate.owner().IsParameterizedDerivedType()) {
+        } else if (ultimate.owner().IsParameterizedDerivedType()) {
           // Save the expression for per-instantiation analysis.
           details->set_unanalyzedPDTComponentInit(&expr.thing.value());
-        } else {
-          if (MaybeExpr folded{EvaluateNonPointerInitializer(
-                  ultimate, expr, expr.thing.value().source)}) {
-            details->set_init(std::move(*folded));
-          }
+        } else if (MaybeExpr folded{EvaluateNonPointerInitializer(
+                       ultimate, expr, expr.thing.value().source)}) {
+          details->set_init(std::move(*folded));
+          ultimate.set(Symbol::Flag::InDataStmt, false);
         }
       } else {
         Say(name, "'%s' is not an object that can be initialized"_err_en_US);
@@ -8424,6 +8447,7 @@ bool ResolveNamesVisitor::Pre(const parser::ProgramUnit &x) {
   ResolveSpecificationParts(root);
   FinishSpecificationParts(root);
   ResolveExecutionParts(root);
+  FinishExecutionParts(root);
   ResolveAccParts(context(), x);
   ResolveOmpParts(context(), x);
   return false;
@@ -8841,6 +8865,8 @@ class DeferredCheckVisitor {
     }
   }
 
+  bool Pre(const parser::BlockConstruct &x) { return true; } // true: keep walking into the BLOCK
+
   void Post(const parser::ProcInterface &pi) {
     if (const auto *name{std::get_if<parser::Name>(&pi.u)}) {
       resolver_.CheckExplicitInterface(*name);
@@ -8871,7 +8897,6 @@ class DeferredCheckVisitor {
       resolver_.CheckBindings(tbps);
     }
   }
-  bool Pre(const parser::StmtFunctionStmt &stmtFunc) { return false; }
 
 private:
   void Init(const parser::Name &name,
@@ -8880,6 +8905,9 @@ class DeferredCheckVisitor {
       if (const auto *target{
               std::get_if<parser::InitialDataTarget>(&init->u)}) {
         resolver_.PointerInitialization(name, *target);
+      } else if (const auto *expr{
+                     std::get_if<parser::ConstantExpr>(&init->u)}) {
+        resolver_.NonPointerInitialization(name, *expr);
       }
     }
   }
@@ -8894,15 +8922,16 @@ void ResolveNamesVisitor::FinishSpecificationParts(const ProgramTree &node) {
   if (!node.scope()) {
     return; // error occurred creating scope
   }
+  auto flagRestorer{common::ScopedSet(inSpecificationPart_, true)};
   SetScope(*node.scope());
-  // The initializers of pointers, the default initializers of pointer
-  // components, non-deferred type-bound procedure bindings have not
-  // yet been traversed.
-  // We do that now, when any (formerly) forward references that appear
+  // The initializers of pointers and non-PARAMETER objects, the default
+  // initializers of components, and non-deferred type-bound procedure
+  // bindings have not yet been traversed.
+  // We do that now, when any forward references that appeared
   // in those initializers will resolve to the right symbols without
-  // incurring spurious errors with IMPLICIT NONE.
+  // incurring spurious errors with IMPLICIT NONE or forward references
+  // to nested subprograms.
   DeferredCheckVisitor{*this}.Walk(node.spec());
-  DeferredCheckVisitor{*this}.Walk(node.exec()); // for BLOCK
   for (Scope &childScope : currScope().children()) {
     if (childScope.IsParameterizedDerivedTypeInstantiation()) {
       FinishDerivedTypeInstantiation(childScope);
@@ -8913,6 +8942,18 @@ void ResolveNamesVisitor::FinishSpecificationParts(const ProgramTree &node) {
   }
 }
 
+void ResolveNamesVisitor::FinishExecutionParts(const ProgramTree &node) { // deferred checks on exec parts
+  if (node.scope()) { // no scope: an error occurred creating it
+    SetScope(*node.scope()); // run the checks in this node's scope
+    if (node.exec()) {
+      DeferredCheckVisitor{*this}.Walk(*node.exec()); // covers BLOCK constructs
+    }
+    for (const auto &child : node.children()) { // recurse over the program tree
+      FinishExecutionParts(child);
+    }
+  }
+}
+
 // Duplicate and fold component object pointer default initializer designators
 // using the actual type parameter values of each particular instantiation.
 // Validation is done later in declaration checking.
diff --git a/flang/test/Semantics/bad-forward-type.f90 b/flang/test/Semantics/bad-forward-type.f90
index 19e23e654642fd..432d450a15f3fc 100644
--- a/flang/test/Semantics/bad-forward-type.f90
+++ b/flang/test/Semantics/bad-forward-type.f90
@@ -84,7 +84,6 @@ subroutine s9
   type con
     Type(t(3)), pointer :: y
   end type
-  !ERROR: Cannot construct value for derived type 't' before it is defined
   Integer :: nn = Size(Transfer(t(3)(666),[0]))
   type :: t(n)
     integer, kind :: n = 3
diff --git a/flang/test/Semantics/init01.f90 b/flang/test/Semantics/init01.f90
index 9f75a8d5567330..0f5a2144c79f97 100644
--- a/flang/test/Semantics/init01.f90
+++ b/flang/test/Semantics/init01.f90
@@ -90,7 +90,7 @@ subroutine components(n)
   real, pointer :: p10 => o3%x
   associate (a1 => o3, a2 => o3%x)
     block
-      real, pointer :: p11 => a1
+      type(t3), pointer :: p11 => a1
       real, pointer :: p12 => a2
     end block
   end associate
diff --git a/flang/test/Semantics/pointer01.f90 b/flang/test/Semantics/pointer01.f90
index cb860f3a3f437c..9e87d1b689eb2b 100644
--- a/flang/test/Semantics/pointer01.f90
+++ b/flang/test/Semantics/pointer01.f90
@@ -16,6 +16,7 @@ program main
   !ERROR: 'inner' cannot have the POINTER attribute
   pointer inner
   real obj
+  !ERROR: 'ip' is a pointer but is not initialized like one
   !ERROR: 'ip' may not have both the POINTER and PARAMETER attributes
   integer, parameter :: ip = 123
   pointer ip
diff --git a/flang/test/Semantics/symbol15.f90 b/flang/test/Semantics/symbol15.f90
index 318819e224cd8c..97dc50a23845f4 100644
--- a/flang/test/Semantics/symbol15.f90
+++ b/flang/test/Semantics/symbol15.f90
@@ -14,10 +14,10 @@ subroutine iface
  !DEF: /m/op2 POINTER, PUBLIC ObjectEntity REAL(4)
  !DEF: /m/null INTRINSIC, PUBLIC, PURE (Function) ProcEntity
  real, pointer :: op2 => null()
- !DEF: /m/op3 POINTER, PUBLIC (InDataStmt) ObjectEntity REAL(4)
+ !DEF: /m/op3 POINTER, PUBLIC ObjectEntity REAL(4)
  !DEF: /m/x PUBLIC, TARGET ObjectEntity REAL(4)
  real, pointer :: op3 => x
- !DEF: /m/op4 POINTER, PUBLIC (InDataStmt) ObjectEntity REAL(4)
+ !DEF: /m/op4 POINTER, PUBLIC ObjectEntity REAL(4)
  !DEF: /m/y PUBLIC, TARGET ObjectEntity REAL(4)
  real, pointer :: op4 => y(1)
  !REF: /m/iface
@@ -50,10 +50,10 @@ subroutine iface
   !DEF: /m/t1/opc2 POINTER ObjectEntity REAL(4)
   !REF: /m/null
   real, pointer :: opc2 => null()
-  !DEF: /m/t1/opc3 POINTER (InDataStmt) ObjectEntity REAL(4)
+  !DEF: /m/t1/opc3 POINTER ObjectEntity REAL(4)
   !REF: /m/x
   real, pointer :: opc3 => x
-  !DEF: /m/t1/opc4 POINTER (InDataStmt) ObjectEntity REAL(4)
+  !DEF: /m/t1/opc4 POINTER ObjectEntity REAL(4)
   !REF: /m/y
   real, pointer :: opc4 => y(1)
   !REF: /m/iface
@@ -100,10 +100,10 @@ subroutine iface
   !DEF: /m/pdt1/opc2 POINTER ObjectEntity REAL(4)
   !REF: /m/null
   real, pointer :: opc2 => null()
-  !DEF: /m/pdt1/opc3 POINTER (InDataStmt) ObjectEntity REAL(4)
+  !DEF: /m/pdt1/opc3 POINTER ObjectEntity REAL(4)
   !REF: /m/x
   real, pointer :: opc3 => x
-  !DEF: /m/pdt1/opc4 POINTER (InDataStmt) ObjectEntity REAL(4)
+  !DEF: /m/pdt1/opc4 POINTER ObjectEntity REAL(4)
   !REF: /m/y
   !REF: /m/pdt1/k
   real, pointer :: opc4 => y(k)
@@ -160,10 +160,10 @@ subroutine iface
   subroutine ext2
   end subroutine
  end interface
- !DEF: /m/op10 POINTER, PUBLIC(InDataStmt) ObjectEntity REAL(4)
+ !DEF: /m/op10 POINTER, PUBLIC ObjectEntity REAL(4)
  !REF: /m/x
  real, pointer :: op10 => x
- !DEF: /m/op11 POINTER, PUBLIC(InDataStmt) ObjectEntity REAL(4)
+ !DEF: /m/op11 POINTER, PUBLIC ObjectEntity REAL(4)
  !REF: /m/y
  real, pointer :: op11 => y(1)
  !REF: /m/iface
@@ -176,10 +176,10 @@ subroutine ext2
  procedure(iface), pointer :: pp11 => ext2
  !DEF: /m/t2 PUBLIC DerivedType
  type :: t2
-  !DEF: /m/t2/opc10 POINTER (InDataStmt) ObjectEntity REAL(4)
+  !DEF: /m/t2/opc10 POINTER ObjectEntity REAL(4)
   !REF: /m/x
   real, pointer :: opc10 => x
-  !DEF: /m/t2/opc11 POINTER (InDataStmt) ObjectEntity REAL(4)
+  !DEF: /m/t2/opc11 POINTER ObjectEntity REAL(4)
   !REF: /m/y
   real, pointer :: opc11 => y(1)
   !REF: /m/iface
@@ -203,10 +203,10 @@ subroutine ext2
  type :: pdt2(k)
   !REF: /m/pdt2/k
   integer, kind :: k
-  !DEF: /m/pdt2/opc10 POINTER (InDataStmt) ObjectEntity REAL(4)
+  !DEF: /m/pdt2/opc10 POINTER ObjectEntity REAL(4)
   !REF: /m/x
   real, pointer :: opc10 => x
-  !DEF: /m/pdt2/opc11 POINTER (InDataStmt) ObjectEntity REAL(4)
+  !DEF: /m/pdt2/opc11 POINTER ObjectEntity REAL(4)
   !REF: /m/y
   !REF: /m/pdt2/k
   real, pointer :: opc11 => y(k)



More information about the flang-commits mailing list