[flang-commits] [flang] [flang] Fix spurious error on defined assignment in PURE (PR #139186)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Mon May 12 12:09:31 PDT 2025


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/139186

>From 4776a6a226254dc6d97428d986c7907edc672b5a Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Thu, 8 May 2025 17:46:35 -0700
Subject: [PATCH] [flang] Fix spurious error on defined assignment in PURE

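In a PURE subprogram, an assignment to a whole polymorphic object
that is implemented by a defined assignment procedure shouldn't be
subjected to the same definability checks as an intrinsic assignment
(which would additionally require the object to be allocatable).

To that end, the "polymorphic in a pure subprogram" definability
check now applies only when the definition might deallocate the
object, as signaled by the new DefinabilityFlag::PotentialDeallocation
(set for intrinsic assignment to a whole allocatable, DEALLOCATE, and
automatically deallocated local variables).  Its message now reads
"is a whole polymorphic object in a pure subprogram", and the
polymorphic/coindexed left-hand side errors now explicitly refer to
"intrinsic assignment".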

Fixes https://github.com/llvm/llvm-project/issues/139129.
---
 flang/include/flang/Evaluate/tools.h       | 31 ++++++-----------
 flang/lib/Evaluate/tools.cpp               | 38 +++++++++++++++-----
 flang/lib/Semantics/assignment.cpp         |  5 +++
 flang/lib/Semantics/check-deallocate.cpp   |  6 ++--
 flang/lib/Semantics/check-declarations.cpp |  4 +--
 flang/lib/Semantics/definable.cpp          | 40 +++++++++++-----------
 flang/lib/Semantics/definable.h            |  2 +-
 flang/lib/Semantics/expression.cpp         |  6 ++--
 flang/test/Semantics/assign11.f90          |  6 ++--
 flang/test/Semantics/bug139129.f90         | 17 +++++++++
 flang/test/Semantics/call28.f90            |  4 +--
 flang/test/Semantics/deallocate07.f90      |  8 ++---
 flang/test/Semantics/declarations05.f90    |  2 +-
 13 files changed, 101 insertions(+), 68 deletions(-)
 create mode 100644 flang/test/Semantics/bug139129.f90

diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 922af4190822d..14baa0371231c 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -502,42 +502,31 @@ template <typename A> std::optional<Substring> ExtractSubstring(const A &x) {
 
 // If an expression is simply a whole symbol data designator,
 // extract and return that symbol, else null.
+const Symbol *UnwrapWholeSymbolDataRef(const DataRef &);
+const Symbol *UnwrapWholeSymbolDataRef(const std::optional<DataRef> &);
 template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
-  if (auto dataRef{ExtractDataRef(x)}) {
-    if (const SymbolRef * p{std::get_if<SymbolRef>(&dataRef->u)}) {
-      return &p->get();
-    }
-  }
-  return nullptr;
+  return UnwrapWholeSymbolDataRef(ExtractDataRef(x));
 }
 
 // If an expression is a whole symbol or a whole component desginator,
 // extract and return that symbol, else null.
+const Symbol *UnwrapWholeSymbolOrComponentDataRef(const DataRef &);
+const Symbol *UnwrapWholeSymbolOrComponentDataRef(
+    const std::optional<DataRef> &);
 template <typename A>
 const Symbol *UnwrapWholeSymbolOrComponentDataRef(const A &x) {
-  if (auto dataRef{ExtractDataRef(x)}) {
-    if (const SymbolRef * p{std::get_if<SymbolRef>(&dataRef->u)}) {
-      return &p->get();
-    } else if (const Component * c{std::get_if<Component>(&dataRef->u)}) {
-      if (c->base().Rank() == 0) {
-        return &c->GetLastSymbol();
-      }
-    }
-  }
-  return nullptr;
+  return UnwrapWholeSymbolOrComponentDataRef(ExtractDataRef(x));
 }
 
 // If an expression is a whole symbol or a whole component designator,
 // potentially followed by an image selector, extract and return that symbol,
 // else null.
 const Symbol *UnwrapWholeSymbolOrComponentOrCoarrayRef(const DataRef &);
+const Symbol *UnwrapWholeSymbolOrComponentOrCoarrayRef(
+    const std::optional<DataRef> &);
 template <typename A>
 const Symbol *UnwrapWholeSymbolOrComponentOrCoarrayRef(const A &x) {
-  if (auto dataRef{ExtractDataRef(x)}) {
-    return UnwrapWholeSymbolOrComponentOrCoarrayRef(*dataRef);
-  } else {
-    return nullptr;
-  }
+  return UnwrapWholeSymbolOrComponentOrCoarrayRef(ExtractDataRef(x));
 }
 
 // GetFirstSymbol(A%B%C[I]%D) -> A
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index d39e4c42928f3..641dead6c55d5 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1320,17 +1320,39 @@ std::optional<parser::MessageFixedText> CheckProcCompatibility(bool isCall,
   return msg;
 }
 
+const Symbol *UnwrapWholeSymbolDataRef(const DataRef &dataRef) {
+  const SymbolRef *p{std::get_if<SymbolRef>(&dataRef.u)};
+  return p ? &p->get() : nullptr;
+}
+
+const Symbol *UnwrapWholeSymbolDataRef(const std::optional<DataRef> &dataRef) {
+  return dataRef ? UnwrapWholeSymbolDataRef(*dataRef) : nullptr;
+}
+
+const Symbol *UnwrapWholeSymbolOrComponentDataRef(const DataRef &dataRef) {
+  if (const Component * c{std::get_if<Component>(&dataRef.u)}) {
+    return c->base().Rank() == 0 ? &c->GetLastSymbol() : nullptr;
+  } else {
+    return UnwrapWholeSymbolDataRef(dataRef);
+  }
+}
+
+const Symbol *UnwrapWholeSymbolOrComponentDataRef(
+    const std::optional<DataRef> &dataRef) {
+  return dataRef ? UnwrapWholeSymbolOrComponentDataRef(*dataRef) : nullptr;
+}
+
 const Symbol *UnwrapWholeSymbolOrComponentOrCoarrayRef(const DataRef &dataRef) {
-  if (const SymbolRef * p{std::get_if<SymbolRef>(&dataRef.u)}) {
-    return &p->get();
-  } else if (const Component * c{std::get_if<Component>(&dataRef.u)}) {
-    if (c->base().Rank() == 0) {
-      return &c->GetLastSymbol();
-    }
-  } else if (const CoarrayRef * c{std::get_if<CoarrayRef>(&dataRef.u)}) {
+  if (const CoarrayRef * c{std::get_if<CoarrayRef>(&dataRef.u)}) {
     return UnwrapWholeSymbolOrComponentOrCoarrayRef(c->base());
+  } else {
+    return UnwrapWholeSymbolOrComponentDataRef(dataRef);
   }
-  return nullptr;
+}
+
+const Symbol *UnwrapWholeSymbolOrComponentOrCoarrayRef(
+    const std::optional<DataRef> &dataRef) {
+  return dataRef ? UnwrapWholeSymbolOrComponentOrCoarrayRef(*dataRef) : nullptr;
 }
 
 // GetLastPointerSymbol()
diff --git a/flang/lib/Semantics/assignment.cpp b/flang/lib/Semantics/assignment.cpp
index 935f5a03bdb6a..6e55d0210ee0e 100644
--- a/flang/lib/Semantics/assignment.cpp
+++ b/flang/lib/Semantics/assignment.cpp
@@ -72,6 +72,11 @@ void AssignmentContext::Analyze(const parser::AssignmentStmt &stmt) {
         std::holds_alternative<evaluate::ProcedureRef>(assignment->u)};
     if (isDefinedAssignment) {
       flags.set(DefinabilityFlag::AllowEventLockOrNotifyType);
+    } else if (const Symbol *
+        whole{evaluate::UnwrapWholeSymbolOrComponentDataRef(lhs)}) {
+      if (IsAllocatable(whole->GetUltimate())) {
+        flags.set(DefinabilityFlag::PotentialDeallocation);
+      }
     }
     if (auto whyNot{WhyNotDefinable(lhsLoc, scope, flags, lhs)}) {
       if (whyNot->IsFatal()) {
diff --git a/flang/lib/Semantics/check-deallocate.cpp b/flang/lib/Semantics/check-deallocate.cpp
index 3bcd4d87b0906..c45b58586853b 100644
--- a/flang/lib/Semantics/check-deallocate.cpp
+++ b/flang/lib/Semantics/check-deallocate.cpp
@@ -36,7 +36,8 @@ void DeallocateChecker::Leave(const parser::DeallocateStmt &deallocateStmt) {
               } else if (auto whyNot{WhyNotDefinable(name.source,
                              context_.FindScope(name.source),
                              {DefinabilityFlag::PointerDefinition,
-                                 DefinabilityFlag::AcceptAllocatable},
+                                 DefinabilityFlag::AcceptAllocatable,
+                                 DefinabilityFlag::PotentialDeallocation},
                              *symbol)}) {
                 // Catch problems with non-definability of the
                 // pointer/allocatable
@@ -74,7 +75,8 @@ void DeallocateChecker::Leave(const parser::DeallocateStmt &deallocateStmt) {
                 } else if (auto whyNot{WhyNotDefinable(source,
                                context_.FindScope(source),
                                {DefinabilityFlag::PointerDefinition,
-                                   DefinabilityFlag::AcceptAllocatable},
+                                   DefinabilityFlag::AcceptAllocatable,
+                                   DefinabilityFlag::PotentialDeallocation},
                                *expr)}) {
                   context_
                       .Say(source,
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index 318085518cc57..c3a228f3ab8a9 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -949,8 +949,8 @@ void CheckHelper::CheckObjectEntity(
       !IsFunctionResult(symbol) /*ditto*/) {
     // Check automatically deallocated local variables for possible
     // problems with finalization in PURE.
-    if (auto whyNot{
-            WhyNotDefinable(symbol.name(), symbol.owner(), {}, symbol)}) {
+    if (auto whyNot{WhyNotDefinable(symbol.name(), symbol.owner(),
+            {DefinabilityFlag::PotentialDeallocation}, symbol)}) {
       if (auto *msg{messages_.Say(
               "'%s' may not be a local variable in a pure subprogram"_err_en_US,
               symbol.name())}) {
diff --git a/flang/lib/Semantics/definable.cpp b/flang/lib/Semantics/definable.cpp
index 99a31553f2782..08cb268b318ae 100644
--- a/flang/lib/Semantics/definable.cpp
+++ b/flang/lib/Semantics/definable.cpp
@@ -193,6 +193,15 @@ static std::optional<parser::Message> WhyNotDefinableLast(parser::CharBlock at,
       return WhyNotDefinableLast(at, scope, flags, dataRef->GetLastSymbol());
     }
   }
+  auto dyType{evaluate::DynamicType::From(ultimate)};
+  const auto *inPure{FindPureProcedureContaining(scope)};
+  if (inPure && !flags.test(DefinabilityFlag::PolymorphicOkInPure) &&
+      flags.test(DefinabilityFlag::PotentialDeallocation) && dyType &&
+      dyType->IsPolymorphic()) {
+    return BlameSymbol(at,
+        "'%s' is a whole polymorphic object in a pure subprogram"_en_US,
+        original);
+  }
   if (flags.test(DefinabilityFlag::PointerDefinition)) {
     if (flags.test(DefinabilityFlag::AcceptAllocatable)) {
       if (!IsAllocatableOrObjectPointer(&ultimate)) {
@@ -210,26 +219,17 @@ static std::optional<parser::Message> WhyNotDefinableLast(parser::CharBlock at,
         "'%s' is an entity with either an EVENT_TYPE or LOCK_TYPE"_en_US,
         original);
   }
-  if (FindPureProcedureContaining(scope)) {
-    if (auto dyType{evaluate::DynamicType::From(ultimate)}) {
-      if (!flags.test(DefinabilityFlag::PolymorphicOkInPure)) {
-        if (dyType->IsPolymorphic()) { // C1596
-          return BlameSymbol(
-              at, "'%s' is polymorphic in a pure subprogram"_en_US, original);
-        }
-      }
-      if (const Symbol * impure{HasImpureFinal(ultimate)}) {
-        return BlameSymbol(at, "'%s' has an impure FINAL procedure '%s'"_en_US,
-            original, impure->name());
-      }
+  if (dyType && inPure) {
+    if (const Symbol * impure{HasImpureFinal(ultimate)}) {
+      return BlameSymbol(at, "'%s' has an impure FINAL procedure '%s'"_en_US,
+          original, impure->name());
+    }
+    if (!flags.test(DefinabilityFlag::PolymorphicOkInPure)) {
       if (const DerivedTypeSpec * derived{GetDerivedTypeSpec(dyType)}) {
-        if (!flags.test(DefinabilityFlag::PolymorphicOkInPure)) {
-          if (auto bad{
-                  FindPolymorphicAllocatablePotentialComponent(*derived)}) {
-            return BlameSymbol(at,
-                "'%s' has polymorphic component '%s' in a pure subprogram"_en_US,
-                original, bad.BuildResultDesignatorName());
-          }
+        if (auto bad{FindPolymorphicAllocatablePotentialComponent(*derived)}) {
+          return BlameSymbol(at,
+              "'%s' has polymorphic component '%s' in a pure subprogram"_en_US,
+              original, bad.BuildResultDesignatorName());
         }
       }
     }
@@ -243,7 +243,7 @@ static std::optional<parser::Message> WhyNotDefinable(parser::CharBlock at,
     const evaluate::DataRef &dataRef) {
   auto whyNotBase{
       WhyNotDefinableBase(at, scope, flags, dataRef.GetFirstSymbol(),
-          std::holds_alternative<evaluate::SymbolRef>(dataRef.u),
+          evaluate::UnwrapWholeSymbolDataRef(dataRef) != nullptr,
           DefinesComponentPointerTarget(dataRef, flags))};
   if (!whyNotBase || !whyNotBase->IsFatal()) {
     if (auto whyNotLast{
diff --git a/flang/lib/Semantics/definable.h b/flang/lib/Semantics/definable.h
index 902702dbccbf3..0d027961417be 100644
--- a/flang/lib/Semantics/definable.h
+++ b/flang/lib/Semantics/definable.h
@@ -33,7 +33,7 @@ ENUM_CLASS(DefinabilityFlag,
     SourcedAllocation, // ALLOCATE(a,SOURCE=)
     PolymorphicOkInPure, // don't check for polymorphic type in pure subprogram
     DoNotNoteDefinition, // context does not imply definition
-    AllowEventLockOrNotifyType)
+    AllowEventLockOrNotifyType, PotentialDeallocation)
 
 using DefinabilityFlags =
     common::EnumSet<DefinabilityFlag, DefinabilityFlag_enumSize>;
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 0659536aab98c..2c89bcd981f6d 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -3391,15 +3391,15 @@ const Assignment *ExpressionAnalyzer::Analyze(const parser::AssignmentStmt &x) {
             const Symbol *lastWhole{
                 lastWhole0 ? &ResolveAssociations(*lastWhole0) : nullptr};
             if (!lastWhole || !IsAllocatable(*lastWhole)) {
-              Say("Left-hand side of assignment may not be polymorphic unless assignment is to an entire allocatable"_err_en_US);
+              Say("Left-hand side of intrinsic assignment may not be polymorphic unless assignment is to an entire allocatable"_err_en_US);
             } else if (evaluate::IsCoarray(*lastWhole)) {
-              Say("Left-hand side of assignment may not be polymorphic if it is a coarray"_err_en_US);
+              Say("Left-hand side of intrinsic assignment may not be polymorphic if it is a coarray"_err_en_US);
             }
           }
           if (auto *derived{GetDerivedTypeSpec(*dyType)}) {
             if (auto iter{FindAllocatableUltimateComponent(*derived)}) {
               if (ExtractCoarrayRef(lhs)) {
-                Say("Left-hand side of assignment must not be coindexed due to allocatable ultimate component '%s'"_err_en_US,
+                Say("Left-hand side of intrinsic assignment must not be coindexed due to allocatable ultimate component '%s'"_err_en_US,
                     iter.BuildResultDesignatorName());
               }
             }
diff --git a/flang/test/Semantics/assign11.f90 b/flang/test/Semantics/assign11.f90
index 37216526b5f33..9d70d7109e75e 100644
--- a/flang/test/Semantics/assign11.f90
+++ b/flang/test/Semantics/assign11.f90
@@ -9,10 +9,10 @@ program test
   end type
   type(t) auc[*]
   pa = 1 ! ok
-  !ERROR: Left-hand side of assignment may not be polymorphic unless assignment is to an entire allocatable
+  !ERROR: Left-hand side of intrinsic assignment may not be polymorphic unless assignment is to an entire allocatable
   pp = 1
-  !ERROR: Left-hand side of assignment may not be polymorphic if it is a coarray
+  !ERROR: Left-hand side of intrinsic assignment may not be polymorphic if it is a coarray
   pac = 1
-  !ERROR: Left-hand side of assignment must not be coindexed due to allocatable ultimate component '%a'
+  !ERROR: Left-hand side of intrinsic assignment must not be coindexed due to allocatable ultimate component '%a'
   auc[1] = t()
 end
diff --git a/flang/test/Semantics/bug139129.f90 b/flang/test/Semantics/bug139129.f90
new file mode 100644
index 0000000000000..2f0f865854706
--- /dev/null
+++ b/flang/test/Semantics/bug139129.f90
@@ -0,0 +1,17 @@
+!RUN: %flang_fc1 -fsyntax-only %s
+module m
+  type t
+   contains
+    procedure asst
+    generic :: assignment(=) => asst
+  end type
+ contains
+  pure subroutine asst(lhs, rhs)
+    class(t), intent(in out) :: lhs
+    class(t), intent(in) :: rhs
+  end
+  pure subroutine test(x, y)
+    class(t), intent(in out) :: x, y
+    x = y ! spurious definability error
+  end
+end
diff --git a/flang/test/Semantics/call28.f90 b/flang/test/Semantics/call28.f90
index 51430853d663f..f133276f7547e 100644
--- a/flang/test/Semantics/call28.f90
+++ b/flang/test/Semantics/call28.f90
@@ -11,9 +11,7 @@ pure subroutine s1(x)
   end subroutine
   pure subroutine s2(x)
     class(t), intent(in out) :: x
-    !ERROR: Left-hand side of assignment may not be polymorphic unless assignment is to an entire allocatable
-    !ERROR: Left-hand side of assignment is not definable
-    !BECAUSE: 'x' is polymorphic in a pure subprogram
+    !ERROR: Left-hand side of intrinsic assignment may not be polymorphic unless assignment is to an entire allocatable
     x = t()
   end subroutine
   pure subroutine s3(x)
diff --git a/flang/test/Semantics/deallocate07.f90 b/flang/test/Semantics/deallocate07.f90
index 154c680f47c82..6dcf20e82cf0d 100644
--- a/flang/test/Semantics/deallocate07.f90
+++ b/flang/test/Semantics/deallocate07.f90
@@ -19,11 +19,11 @@ pure subroutine subr(pp1, pp2, mp2)
     !ERROR: Name in DEALLOCATE statement is not definable
     !BECAUSE: 'mv1' may not be defined in pure subprogram 'subr' because it is host-associated
     deallocate(mv1%pc)
-    !ERROR: Object in DEALLOCATE statement is not deallocatable
-    !BECAUSE: 'pp1' is polymorphic in a pure subprogram
+    !ERROR: Name in DEALLOCATE statement is not definable
+    !BECAUSE: 'pp1' is a whole polymorphic object in a pure subprogram
     deallocate(pp1)
-    !ERROR: Object in DEALLOCATE statement is not deallocatable
-    !BECAUSE: 'pc' is polymorphic in a pure subprogram
+    !ERROR: Name in DEALLOCATE statement is not definable
+    !BECAUSE: 'pc' is a whole polymorphic object in a pure subprogram
     deallocate(pp2%pc)
     !ERROR: Object in DEALLOCATE statement is not deallocatable
     !BECAUSE: 'mp2' has polymorphic component '%pc' in a pure subprogram
diff --git a/flang/test/Semantics/declarations05.f90 b/flang/test/Semantics/declarations05.f90
index b6dab7aeea0bc..b1e3d3c773160 100644
--- a/flang/test/Semantics/declarations05.f90
+++ b/flang/test/Semantics/declarations05.f90
@@ -22,7 +22,7 @@ impure subroutine final(x)
   end
   pure subroutine test
     !ERROR: 'x0' may not be a local variable in a pure subprogram
-    !BECAUSE: 'x0' is polymorphic in a pure subprogram
+    !BECAUSE: 'x0' is a whole polymorphic object in a pure subprogram
     class(t0), allocatable :: x0
     !ERROR: 'x1' may not be a local variable in a pure subprogram
     !BECAUSE: 'x1' has an impure FINAL procedure 'final'



More information about the flang-commits mailing list