[flang-commits] [flang] [flang] Get base objects right in definability checker (PR #78854)

via flang-commits flang-commits at lists.llvm.org
Sat Jan 20 09:50:17 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Peter Klausler (klausler)

<details>
<summary>Changes</summary>

The utility function GetRelevantSymbol() appears to be simply wrong for definability checks of the "base object" of a designator, which seems to be its only remaining use. This leads to false error messages in Whizard when data-refs with multiple pointer components are defined. This change simplifies the base-object logic (the leftmost symbol of the data-ref is now used) and adds more test cases.

---
Full diff: https://github.com/llvm/llvm-project/pull/78854.diff


2 Files Affected:

- (modified) flang/lib/Semantics/definable.cpp (+8-48) 
- (modified) flang/test/Semantics/definable01.f90 (+15) 


``````````diff
diff --git a/flang/lib/Semantics/definable.cpp b/flang/lib/Semantics/definable.cpp
index b73290109248a0..69f9be671798bf 100644
--- a/flang/lib/Semantics/definable.cpp
+++ b/flang/lib/Semantics/definable.cpp
@@ -60,33 +60,6 @@ static std::optional<parser::Message> CheckDefinabilityInPureScope(
   return std::nullopt;
 }
 
-// When a DataRef contains pointers, gets the rightmost one (unless it is
-// the entity being defined, in which case the last pointer above it);
-// otherwise, returns the leftmost symbol.  The resulting symbol is the
-// relevant base object for definabiliy checking.  Examples:
-//   ptr1%ptr2        => ...     -> ptr1
-//   nonptr%ptr       => ...     -> nonptr
-//   nonptr%ptr       =  ...     -> ptr
-//   ptr1%ptr2        =  ...     -> ptr2
-//   ptr1%ptr2%nonptr =  ...     -> ptr2
-//   nonptr1%nonptr2  =  ...     -> nonptr1
-static const Symbol &GetRelevantSymbol(const evaluate::DataRef &dataRef,
-    bool isPointerDefinition, bool acceptAllocatable) {
-  if (isPointerDefinition) {
-    if (const auto *component{std::get_if<evaluate::Component>(&dataRef.u)}) {
-      if (IsPointer(component->GetLastSymbol()) ||
-          (acceptAllocatable && IsAllocatable(component->GetLastSymbol()))) {
-        return GetRelevantSymbol(component->base(), false, false);
-      }
-    }
-  }
-  if (const Symbol * lastPointer{GetLastPointerSymbol(dataRef)}) {
-    return *lastPointer;
-  } else {
-    return dataRef.GetFirstSymbol();
-  }
-}
-
 // Check the leftmost (or only) symbol from a data-ref or expression.
 static std::optional<parser::Message> WhyNotDefinableBase(parser::CharBlock at,
     const Scope &scope, DefinabilityFlags flags, const Symbol &original,
@@ -104,9 +77,8 @@ static std::optional<parser::Message> WhyNotDefinableBase(parser::CharBlock at,
           "Construct association '%s' has a vector subscript"_en_US, original);
     } else if (auto dataRef{evaluate::ExtractDataRef(
                    *association->expr(), true, true)}) {
-      return WhyNotDefinableBase(at, scope, flags,
-          GetRelevantSymbol(*dataRef, isPointerDefinition, acceptAllocatable),
-          isWholeSymbol);
+      return WhyNotDefinableBase(
+          at, scope, flags, dataRef->GetFirstSymbol(), isWholeSymbol);
     }
   }
   if (isTargetDefinition) {
@@ -222,29 +194,15 @@ static std::optional<parser::Message> WhyNotDefinableLast(parser::CharBlock at,
 static std::optional<parser::Message> WhyNotDefinable(parser::CharBlock at,
     const Scope &scope, DefinabilityFlags flags,
     const evaluate::DataRef &dataRef) {
-  const Symbol &base{GetRelevantSymbol(dataRef,
-      flags.test(DefinabilityFlag::PointerDefinition),
-      flags.test(DefinabilityFlag::AcceptAllocatable))};
-  if (auto whyNot{WhyNotDefinableBase(at, scope, flags, base,
-          std::holds_alternative<evaluate::SymbolRef>(dataRef.u))}) {
+  if (auto whyNot{
+          WhyNotDefinableBase(at, scope, flags, dataRef.GetFirstSymbol(),
+              std::holds_alternative<evaluate::SymbolRef>(dataRef.u))}) {
     return whyNot;
   } else {
     return WhyNotDefinableLast(at, scope, flags, dataRef.GetLastSymbol());
   }
 }
 
-// Checks a NOPASS procedure pointer component
-static std::optional<parser::Message> WhyNotDefinable(parser::CharBlock at,
-    const Scope &scope, DefinabilityFlags flags,
-    const evaluate::Component &component) {
-  const evaluate::DataRef &dataRef{component.base()};
-  const Symbol &base{GetRelevantSymbol(dataRef, false, false)};
-  DefinabilityFlags baseFlags{flags};
-  baseFlags.reset(DefinabilityFlag::PointerDefinition);
-  return WhyNotDefinableBase(at, scope, baseFlags, base,
-      std::holds_alternative<evaluate::SymbolRef>(dataRef.u));
-}
-
 std::optional<parser::Message> WhyNotDefinable(parser::CharBlock at,
     const Scope &scope, DefinabilityFlags flags, const Symbol &original) {
   if (auto base{WhyNotDefinableBase(at, scope, flags, original, true)}) {
@@ -370,7 +328,9 @@ std::optional<parser::Message> WhyNotDefinable(parser::CharBlock at,
               *procSym, expr.AsFortran());
         }
         if (const auto *component{procDesignator->GetComponent()}) {
-          return WhyNotDefinable(at, scope, flags, *component);
+          flags.reset(DefinabilityFlag::PointerDefinition);
+          return WhyNotDefinableBase(
+              at, scope, flags, component->GetFirstSymbol(), false);
         } else {
           return WhyNotDefinable(at, scope, flags, *procSym);
         }
diff --git a/flang/test/Semantics/definable01.f90 b/flang/test/Semantics/definable01.f90
index c0f10668fb480e..ff71b419fa9713 100644
--- a/flang/test/Semantics/definable01.f90
+++ b/flang/test/Semantics/definable01.f90
@@ -25,6 +25,10 @@ module m
     real :: x2
   end type
   type(t2) :: t2static
+  type list
+    real a
+    type(list), pointer :: prev, next
+  end type
   character(*), parameter :: internal = '0'
  contains
   subroutine test1(dummy)
@@ -97,4 +101,15 @@ subroutine test5(np)
     !CHECK: because: 'np' is an INTENT(IN) dummy argument
     nullify(np%ptr)
   end
+  pure function test6(lp)
+    type(list), pointer :: lp
+    !CHECK: error: The left-hand side of a pointer assignment is not definable
+    !CHECK: because: 'lp' may not be defined in pure subprogram 'test6' because it is a POINTER dummy argument of a pure function
+    lp%next%next => null()
+  end
+  pure subroutine test7(lp)
+    type(list), pointer :: lp
+    !CHECK-NOT: error:
+    lp%next%next => null()
+  end
 end module

``````````

</details>


https://github.com/llvm/llvm-project/pull/78854


More information about the flang-commits mailing list