[flang-commits] [flang] [flang] Make IsCoarray() more accurate (PR #121415)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Fri Jan 3 08:20:04 PST 2025


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/121415

>From 99b8e3c2c687070ab3acc11af4077697580cb3ee Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Tue, 31 Dec 2024 13:20:27 -0800
Subject: [PATCH] [flang] Make IsCoarray() more accurate; fix ASSOCIATE coarray

A designator without cosubscripts can have subscripts, component
references, substrings, etc., and still have a nonzero corank.  The
current IsCoarray() predicate seems to work only for whole
variable/component references, which broke some uses of THIS_IMAGE().

Further, when checking the number of cosubscripts in a coindexed
reference, allow for the possibility that the coarray is an ASSOCIATE
construct entity.
---
 .../include/flang/Evaluate/characteristics.h  |  4 ++
 flang/include/flang/Evaluate/tools.h          | 25 +++++----
 flang/include/flang/Evaluate/variable.h       |  9 ++++
 flang/lib/Evaluate/characteristics.cpp        | 15 +++---
 flang/lib/Evaluate/tools.cpp                  |  8 +--
 flang/lib/Evaluate/variable.cpp               | 53 +++++++++++++++++++
 flang/lib/Semantics/check-call.cpp            |  4 +-
 flang/lib/Semantics/expression.cpp            |  4 +-
 flang/test/Semantics/resolve94.f90            |  7 +++
 flang/test/Semantics/this_image01.f90         |  4 ++
 10 files changed, 104 insertions(+), 29 deletions(-)

diff --git a/flang/include/flang/Evaluate/characteristics.h b/flang/include/flang/Evaluate/characteristics.h
index 11533a7259b055..357fc3e5952436 100644
--- a/flang/include/flang/Evaluate/characteristics.h
+++ b/flang/include/flang/Evaluate/characteristics.h
@@ -102,6 +102,10 @@ class TypeAndShape {
     }
     if (auto type{x.GetType()}) {
       TypeAndShape result{*type, GetShape(context, x, invariantOnly)};
+      result.corank_ = GetCorank(x);
+      if (result.corank_ > 0) {
+        result.attrs_.set(Attr::Coarray);
+      }
       if (type->category() == TypeCategory::Character) {
         if (const auto *chExpr{UnwrapExpr<Expr<SomeCharacter>>(x)}) {
           if (auto length{chExpr->LEN()}) {
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index f586c59d46e54c..b551eb0670e867 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -102,23 +102,22 @@ template <typename A> bool IsAssumedRank(const A *x) {
   return x && IsAssumedRank(*x);
 }
 
-// Predicate: true when an expression is a coarray (corank > 0)
-bool IsCoarray(const ActualArgument &);
-bool IsCoarray(const Symbol &);
-template <typename A> bool IsCoarray(const A &) { return false; }
-template <typename A> bool IsCoarray(const Designator<A> &designator) {
-  if (const auto *symbol{std::get_if<SymbolRef>(&designator.u)}) {
-    return IsCoarray(**symbol);
-  }
-  return false;
+int GetCorank(const ActualArgument &); // 0 if not an expression
+int GetCorank(const Symbol &); // uses the association root
+template <typename A> int GetCorank(const A &) { return 0; } // default: 0
+template <typename T> int GetCorank(const Designator<T> &designator) {
+  return designator.Corank();
 }
-template <typename T> bool IsCoarray(const Expr<T> &expr) {
-  return common::visit([](const auto &x) { return IsCoarray(x); }, expr.u);
+template <typename T> int GetCorank(const Expr<T> &expr) {
+  return common::visit([](const auto &x) { return GetCorank(x); }, expr.u);
 }
-template <typename A> bool IsCoarray(const std::optional<A> &x) {
-  return x && IsCoarray(*x);
+template <typename A> int GetCorank(const std::optional<A> &x) {
+  return x ? GetCorank(*x) : 0;
 }
 
+// Predicate: true when x is a coarray, i.e. when GetCorank(x) > 0
+template <typename A> bool IsCoarray(const A &x) { return GetCorank(x) > 0; }
+
 // Generalizing packagers: these take operations and expressions of more
 // specific types and wrap them in Expr<> containers of more abstract types.
 
diff --git a/flang/include/flang/Evaluate/variable.h b/flang/include/flang/Evaluate/variable.h
index 9565826dbfaea4..178ce80cafded5 100644
--- a/flang/include/flang/Evaluate/variable.h
+++ b/flang/include/flang/Evaluate/variable.h
@@ -51,6 +51,7 @@ template <typename T> struct Variable;
 struct BaseObject {
   EVALUATE_UNION_CLASS_BOILERPLATE(BaseObject)
   int Rank() const;
+  int Corank() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
   llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
   const Symbol *symbol() const {
@@ -84,6 +85,7 @@ class Component {
   SymbolRef &symbol() { return symbol_; }
 
   int Rank() const;
+  int Corank() const;
   const Symbol &GetFirstSymbol() const;
   const Symbol &GetLastSymbol() const { return symbol_; }
   std::optional<Expr<SubscriptInteger>> LEN() const;
@@ -116,6 +118,7 @@ class NamedEntity {
   Component *UnwrapComponent();
 
   int Rank() const;
+  int Corank() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
   bool operator==(const NamedEntity &) const;
   llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
@@ -224,6 +227,7 @@ class ArrayRef {
   }
 
   int Rank() const;
+  int Corank() const;
   const Symbol &GetFirstSymbol() const;
   const Symbol &GetLastSymbol() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
@@ -271,6 +275,7 @@ class CoarrayRef {
   CoarrayRef &set_team(Expr<SomeInteger> &&, bool isTeamNumber = false);
 
   int Rank() const;
+  int Corank() const { return 0; }
   const Symbol &GetFirstSymbol() const;
   const Symbol &GetLastSymbol() const;
   NamedEntity GetBase() const;
@@ -294,6 +299,7 @@ class CoarrayRef {
 struct DataRef {
   EVALUATE_UNION_CLASS_BOILERPLATE(DataRef)
   int Rank() const;
+  int Corank() const;
   const Symbol &GetFirstSymbol() const;
   const Symbol &GetLastSymbol() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
@@ -331,6 +337,7 @@ class Substring {
   Parent &parent() { return parent_; }
 
   int Rank() const;
+  int Corank() const;
   template <typename A> const A *GetParentIf() const {
     return std::get_if<A>(&parent_);
   }
@@ -361,6 +368,7 @@ class ComplexPart {
   const DataRef &complex() const { return complex_; }
   Part part() const { return part_; }
   int Rank() const;
+  int Corank() const;
   const Symbol &GetFirstSymbol() const { return complex_.GetFirstSymbol(); }
   const Symbol &GetLastSymbol() const { return complex_.GetLastSymbol(); }
   bool operator==(const ComplexPart &) const;
@@ -396,6 +404,7 @@ template <typename T> class Designator {
 
   std::optional<DynamicType> GetType() const;
   int Rank() const;
+  int Corank() const;
   BaseObject GetBaseObject() const;
   const Symbol *GetLastSymbol() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
diff --git a/flang/lib/Evaluate/characteristics.cpp b/flang/lib/Evaluate/characteristics.cpp
index 324d6b8dde73b8..3912d1c4b47715 100644
--- a/flang/lib/Evaluate/characteristics.cpp
+++ b/flang/lib/Evaluate/characteristics.cpp
@@ -227,15 +227,14 @@ void TypeAndShape::AcquireAttrs(const semantics::Symbol &symbol) {
   } else if (semantics::IsAssumedSizeArray(symbol)) {
     attrs_.set(Attr::AssumedSize);
   }
+  if (int n{GetCorank(symbol)}) {
+    corank_ = n;
+    attrs_.set(Attr::Coarray);
+  }
   if (const auto *object{
-          symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
-    corank_ = object->coshape().Rank();
-    if (object->IsAssumedRank()) {
-      attrs_.set(Attr::AssumedRank);
-    }
-    if (object->IsCoarray()) {
-      attrs_.set(Attr::Coarray);
-    }
+          symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()};
+      object && object->IsAssumedRank()) {
+    attrs_.set(Attr::AssumedRank);
   }
 }
 
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 6299084d729b2d..bd52a789168ed8 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -906,13 +906,13 @@ bool IsAssumedRank(const ActualArgument &arg) {
   }
 }
 
-bool IsCoarray(const ActualArgument &arg) {
+int GetCorank(const ActualArgument &arg) {
   const auto *expr{arg.UnwrapExpr()};
-  return expr && IsCoarray(*expr);
+  return expr ? GetCorank(*expr) : 0; // 0 when arg holds no expression
 }
 
-bool IsCoarray(const Symbol &symbol) {
-  return GetAssociationRoot(symbol).Corank() > 0;
+int GetCorank(const Symbol &symbol) {
+  return GetAssociationRoot(symbol).Corank(); // follows ASSOCIATE to root
 }
 
 bool IsProcedureDesignator(const Expr<SomeType> &expr) {
diff --git a/flang/lib/Evaluate/variable.cpp b/flang/lib/Evaluate/variable.cpp
index 707a2065ca30a7..841d0f71ed0e2f 100644
--- a/flang/lib/Evaluate/variable.cpp
+++ b/flang/lib/Evaluate/variable.cpp
@@ -465,6 +465,59 @@ template <typename T> int Designator<T>::Rank() const {
       u);
 }
 
+// Corank(): from the rightmost part with corank > 0; coindexed refs have 0
+int BaseObject::Corank() const {
+  return common::visit(common::visitors{
+                           [](SymbolRef symbol) { return symbol->Corank(); },
+                           [](const StaticDataObject::Pointer &) { return 0; },
+                       },
+      u);
+}
+
+int Component::Corank() const {
+  if (int corank{symbol_->Corank()}; corank > 0) { // own corank wins
+    return corank;
+  }
+  return base().Corank(); // otherwise inherit from the base
+}
+
+int NamedEntity::Corank() const {
+  return common::visit(common::visitors{
+                           [](const SymbolRef s) { return s->Corank(); },
+                           [](const Component &c) { return c.Corank(); },
+                       },
+      u_);
+}
+
+int ArrayRef::Corank() const { return base().Corank(); } // from the base
+
+int DataRef::Corank() const {
+  return common::visit(common::visitors{
+                           [](SymbolRef symbol) { return symbol->Corank(); },
+                           [](const auto &x) { return x.Corank(); },
+                       },
+      u);
+}
+
+int Substring::Corank() const {
+  return common::visit(
+      common::visitors{
+          [](const DataRef &dataRef) { return dataRef.Corank(); },
+          [](const StaticDataObject::Pointer &) { return 0; },
+      },
+      parent_);
+}
+
+int ComplexPart::Corank() const { return complex_.Corank(); } // from complex_
+
+template <typename T> int Designator<T>::Corank() const {
+  return common::visit(common::visitors{
+                           [](SymbolRef symbol) { return symbol->Corank(); },
+                           [](const auto &x) { return x.Corank(); },
+                       },
+      u);
+}
+
 // GetBaseObject(), GetFirstSymbol(), GetLastSymbol(), &c.
 const Symbol &Component::GetFirstSymbol() const {
   return base_.value().GetFirstSymbol();
diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp
index 597c280a6df8bc..95df34b4a1f3e9 100644
--- a/flang/lib/Semantics/check-call.cpp
+++ b/flang/lib/Semantics/check-call.cpp
@@ -1622,8 +1622,8 @@ static void CheckImage_Index(evaluate::ActualArguments &arguments,
             evaluate::GetShape(arguments[1]->UnwrapExpr())}) {
       if (const auto *coarrayArgSymbol{UnwrapWholeSymbolOrComponentDataRef(
               arguments[0]->UnwrapExpr())}) {
-        const auto coarrayArgCorank = coarrayArgSymbol->Corank();
-        if (const auto subArrSize = evaluate::ToInt64(*subArrShape->front())) {
+        auto coarrayArgCorank{coarrayArgSymbol->Corank()};
+        if (auto subArrSize{evaluate::ToInt64(*subArrShape->front())}) {
           if (subArrSize != coarrayArgCorank) {
             messages.Say(arguments[1]->sourceLocation(),
                 "The size of 'SUB=' (%jd) for intrinsic 'image_index' must be equal to the corank of 'COARRAY=' (%d)"_err_en_US,
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index c2eb17c1ac8e5b..1274feb388721a 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -1506,9 +1506,9 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::CoindexedNamedObject &x) {
     if (cosubsOk && !reversed.empty()) {
       int numCosubscripts{static_cast<int>(cosubscripts.size())};
       const Symbol &symbol{reversed.front()};
-      if (numCosubscripts != symbol.Corank()) {
+      if (numCosubscripts != GetCorank(symbol)) {
         Say("'%s' has corank %d, but coindexed reference has %d cosubscripts"_err_en_US,
-            symbol.name(), symbol.Corank(), numCosubscripts);
+            symbol.name(), GetCorank(symbol), numCosubscripts);
       }
     }
     for (const auto &imageSelSpec :
diff --git a/flang/test/Semantics/resolve94.f90 b/flang/test/Semantics/resolve94.f90
index e47ab4a433829b..19c06ad0d16228 100644
--- a/flang/test/Semantics/resolve94.f90
+++ b/flang/test/Semantics/resolve94.f90
@@ -17,8 +17,15 @@ subroutine s1()
   intCoVar = 343
   ! OK
   rVar1 = rCoarray[1,2,3]
+  associate (x => rCoarray)
+    rVar1 = x[1,2,3] ! also ok
+  end associate
   !ERROR: 'rcoarray' has corank 3, but coindexed reference has 2 cosubscripts
   rVar1 = rCoarray[1,2]
+  associate (x => rCoarray)
+  !ERROR: 'x' has corank 3, but coindexed reference has 2 cosubscripts
+    rVar1 = x[1,2]
+  end associate
   !ERROR: Must have INTEGER type, but is REAL(4)
   rVar1 = rCoarray[1,2,3.4]
   !ERROR: Must have INTEGER type, but is REAL(4)
diff --git a/flang/test/Semantics/this_image01.f90 b/flang/test/Semantics/this_image01.f90
index 0e59aa3fa27c6b..efe39e2fee503c 100644
--- a/flang/test/Semantics/this_image01.f90
+++ b/flang/test/Semantics/this_image01.f90
@@ -17,6 +17,10 @@ subroutine test
   print *, this_image(coarray, team)
   print *, this_image(coarray, 1)
   print *, this_image(coarray, 1, team)
+  print *, this_image(coarray(1))
+  print *, this_image(coarray(1), team)
+  print *, this_image(coarray(1), 1)
+  print *, this_image(coarray(1), 1, team)
   print *, this_image(coscalar)
   print *, this_image(coscalar, team)
   print *, this_image(coscalar, 1)



More information about the flang-commits mailing list