[flang-commits] [flang] [flang] Get base objects right in definability checker (PR #78854)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Tue Jan 23 10:14:41 PST 2024


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/78854

>From 4d6ca3600825200a1ad5e6cd48e2f667dbf0f614 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Sat, 20 Jan 2024 09:45:14 -0800
Subject: [PATCH] [flang] Get base objects right in definability checker

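The utility function GetRelevantSymbol() seems to be simply wrong
for definability checks of the "base object" of a designator,
which is (now, at least) its only use.  This leads to some
false error messages in Whizard when data-refs with multiple
pointer components are defined.  Simplify: always treat the leftmost
symbol of the data-ref as the base object, and separately determine
whether the object being defined is the target of a component pointer
(F'2023 9.4.2p5).  In that case, checks such as "is a variable" and
PROTECTED are skipped for the base object, but the restrictions for
pure subprograms still apply.  Add more test cases.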
---
 flang/lib/Semantics/definable.cpp    | 82 ++++++++++++----------------
 flang/test/Semantics/definable01.f90 | 15 +++++
 2 files changed, 50 insertions(+), 47 deletions(-)

diff --git a/flang/lib/Semantics/definable.cpp b/flang/lib/Semantics/definable.cpp
index b73290109248a0..1d6f7de3f66c3e 100644
--- a/flang/lib/Semantics/definable.cpp
+++ b/flang/lib/Semantics/definable.cpp
@@ -60,37 +60,31 @@ static std::optional<parser::Message> CheckDefinabilityInPureScope(
   return std::nullopt;
 }
 
-// When a DataRef contains pointers, gets the rightmost one (unless it is
-// the entity being defined, in which case the last pointer above it);
-// otherwise, returns the leftmost symbol.  The resulting symbol is the
-// relevant base object for definabiliy checking.  Examples:
-//   ptr1%ptr2        => ...     -> ptr1
-//   nonptr%ptr       => ...     -> nonptr
-//   nonptr%ptr       =  ...     -> ptr
-//   ptr1%ptr2        =  ...     -> ptr2
-//   ptr1%ptr2%nonptr =  ...     -> ptr2
-//   nonptr1%nonptr2  =  ...     -> nonptr1
-static const Symbol &GetRelevantSymbol(const evaluate::DataRef &dataRef,
-    bool isPointerDefinition, bool acceptAllocatable) {
-  if (isPointerDefinition) {
-    if (const auto *component{std::get_if<evaluate::Component>(&dataRef.u)}) {
-      if (IsPointer(component->GetLastSymbol()) ||
-          (acceptAllocatable && IsAllocatable(component->GetLastSymbol()))) {
-        return GetRelevantSymbol(component->base(), false, false);
+// True when the object being defined is not a subobject of the base
+// object, e.g. X%PTR = 1., X%PTR%PTR2 => T (but not X%PTR => T).
+// F'2023 9.4.2p5
+static bool DefinesComponentPointerTarget(
+    const evaluate::DataRef &dataRef, DefinabilityFlags flags) {
+  if (const auto *component{std::get_if<evaluate::Component>(&dataRef.u)}) {
+    const Symbol &compSym{component->GetLastSymbol()};
+    if (IsPointer(compSym) ||
+        (flags.test(DefinabilityFlag::AcceptAllocatable) &&
+            IsAllocatable(compSym))) {
+      if (!flags.test(DefinabilityFlag::PointerDefinition)) {
+        return true;
       }
     }
-  }
-  if (const Symbol * lastPointer{GetLastPointerSymbol(dataRef)}) {
-    return *lastPointer;
+    flags.reset(DefinabilityFlag::PointerDefinition);
+    return DefinesComponentPointerTarget(component->base(), flags);
   } else {
-    return dataRef.GetFirstSymbol();
+    return false;
   }
 }
 
 // Check the leftmost (or only) symbol from a data-ref or expression.
 static std::optional<parser::Message> WhyNotDefinableBase(parser::CharBlock at,
     const Scope &scope, DefinabilityFlags flags, const Symbol &original,
-    bool isWholeSymbol) {
+    bool isWholeSymbol, bool isComponentPointerTarget) {
   const Symbol &ultimate{original.GetUltimate()};
   bool isPointerDefinition{flags.test(DefinabilityFlag::PointerDefinition)};
   bool acceptAllocatable{flags.test(DefinabilityFlag::AcceptAllocatable)};
@@ -104,12 +98,14 @@ static std::optional<parser::Message> WhyNotDefinableBase(parser::CharBlock at,
           "Construct association '%s' has a vector subscript"_en_US, original);
     } else if (auto dataRef{evaluate::ExtractDataRef(
                    *association->expr(), true, true)}) {
-      return WhyNotDefinableBase(at, scope, flags,
-          GetRelevantSymbol(*dataRef, isPointerDefinition, acceptAllocatable),
-          isWholeSymbol);
+      return WhyNotDefinableBase(at, scope, flags, dataRef->GetFirstSymbol(),
+          isWholeSymbol &&
+              std::holds_alternative<evaluate::SymbolRef>(dataRef->u),
+          isComponentPointerTarget ||
+              DefinesComponentPointerTarget(*dataRef, flags));
     }
   }
-  if (isTargetDefinition) {
+  if (isTargetDefinition || isComponentPointerTarget) {
   } else if (!isPointerDefinition && !IsVariableName(ultimate)) {
     return BlameSymbol(at, "'%s' is not a variable"_en_US, original);
   } else if (IsProtected(ultimate) && IsUseAssociated(original, scope)) {
@@ -121,7 +117,7 @@ static std::optional<parser::Message> WhyNotDefinableBase(parser::CharBlock at,
   }
   if (const Scope * pure{FindPureProcedureContaining(scope)}) {
     // Additional checking for pure subprograms.
-    if (!isTargetDefinition) {
+    if (!isTargetDefinition || isComponentPointerTarget) {
       if (auto msg{CheckDefinabilityInPureScope(
               at, original, ultimate, scope, *pure)}) {
         return msg;
@@ -222,35 +218,24 @@ static std::optional<parser::Message> WhyNotDefinableLast(parser::CharBlock at,
 static std::optional<parser::Message> WhyNotDefinable(parser::CharBlock at,
     const Scope &scope, DefinabilityFlags flags,
     const evaluate::DataRef &dataRef) {
-  const Symbol &base{GetRelevantSymbol(dataRef,
-      flags.test(DefinabilityFlag::PointerDefinition),
-      flags.test(DefinabilityFlag::AcceptAllocatable))};
-  if (auto whyNot{WhyNotDefinableBase(at, scope, flags, base,
-          std::holds_alternative<evaluate::SymbolRef>(dataRef.u))}) {
+  if (auto whyNot{
+          WhyNotDefinableBase(at, scope, flags, dataRef.GetFirstSymbol(),
+              std::holds_alternative<evaluate::SymbolRef>(dataRef.u),
+              DefinesComponentPointerTarget(dataRef, flags))}) {
     return whyNot;
   } else {
     return WhyNotDefinableLast(at, scope, flags, dataRef.GetLastSymbol());
   }
 }
 
-// Checks a NOPASS procedure pointer component
-static std::optional<parser::Message> WhyNotDefinable(parser::CharBlock at,
-    const Scope &scope, DefinabilityFlags flags,
-    const evaluate::Component &component) {
-  const evaluate::DataRef &dataRef{component.base()};
-  const Symbol &base{GetRelevantSymbol(dataRef, false, false)};
-  DefinabilityFlags baseFlags{flags};
-  baseFlags.reset(DefinabilityFlag::PointerDefinition);
-  return WhyNotDefinableBase(at, scope, baseFlags, base,
-      std::holds_alternative<evaluate::SymbolRef>(dataRef.u));
-}
-
 std::optional<parser::Message> WhyNotDefinable(parser::CharBlock at,
     const Scope &scope, DefinabilityFlags flags, const Symbol &original) {
-  if (auto base{WhyNotDefinableBase(at, scope, flags, original, true)}) {
+  if (auto base{WhyNotDefinableBase(at, scope, flags, original,
+          /*isWholeSymbol=*/true, /*isComponentPointerTarget=*/false)}) {
     return base;
+  } else {
+    return WhyNotDefinableLast(at, scope, flags, original);
   }
-  return WhyNotDefinableLast(at, scope, flags, original);
 }
 
 class DuplicatedSubscriptFinder
@@ -370,7 +355,10 @@ std::optional<parser::Message> WhyNotDefinable(parser::CharBlock at,
               *procSym, expr.AsFortran());
         }
         if (const auto *component{procDesignator->GetComponent()}) {
-          return WhyNotDefinable(at, scope, flags, *component);
+          flags.reset(DefinabilityFlag::PointerDefinition);
+          return WhyNotDefinableBase(at, scope, flags,
+              component->base().GetFirstSymbol(), false,
+              DefinesComponentPointerTarget(component->base(), flags));
         } else {
           return WhyNotDefinable(at, scope, flags, *procSym);
         }
diff --git a/flang/test/Semantics/definable01.f90 b/flang/test/Semantics/definable01.f90
index c0f10668fb480e..ff71b419fa9713 100644
--- a/flang/test/Semantics/definable01.f90
+++ b/flang/test/Semantics/definable01.f90
@@ -25,6 +25,10 @@ module m
     real :: x2
   end type
   type(t2) :: t2static
+  type list
+    real a
+    type(list), pointer :: prev, next
+  end type
   character(*), parameter :: internal = '0'
  contains
   subroutine test1(dummy)
@@ -97,4 +101,15 @@ subroutine test5(np)
     !CHECK: because: 'np' is an INTENT(IN) dummy argument
     nullify(np%ptr)
   end
+  pure function test6(lp)
+    type(list), pointer :: lp
+    !CHECK: error: The left-hand side of a pointer assignment is not definable
+    !CHECK: because: 'lp' may not be defined in pure subprogram 'test6' because it is a POINTER dummy argument of a pure function
+    lp%next%next => null()
+  end
+  pure subroutine test7(lp)
+    type(list), pointer :: lp
+    !CHECK-NOT: error:
+    lp%next%next => null()
+  end
 end module



More information about the flang-commits mailing list