[flang-commits] [flang] [flang] Make IsCoarray() more accurate (PR #121415)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Tue Dec 31 13:23:25 PST 2024


https://github.com/klausler created https://github.com/llvm/llvm-project/pull/121415

A designator without cosubscripts can have subscripts, component references, substrings, etc., and still have corank.  The current IsCoarray() predicate seems to work only for whole variable/component references.  This was breaking some cases of THIS_IMAGE().

From 7ffaf7f2715a342d2109e16569b5dc3eb1e8d29d Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Tue, 31 Dec 2024 13:20:27 -0800
Subject: [PATCH] [flang] Make IsCoarray() more accurate

A designator without cosubscripts can have subscripts, component
references, substrings, &c. and still have corank.  The current
IsCoarray() predicate only seems to work for whole variable/component
references.  This was breaking some cases of THIS_IMAGE().
---
 flang/include/flang/Evaluate/tools.h    |  7 +---
 flang/include/flang/Evaluate/variable.h |  9 +++++
 flang/lib/Evaluate/variable.cpp         | 53 +++++++++++++++++++++++++
 flang/lib/Semantics/check-call.cpp      |  4 +-
 flang/test/Semantics/this_image01.f90   |  4 ++
 5 files changed, 70 insertions(+), 7 deletions(-)

diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index f586c59d46e54c..c6e8cf19b2d9f8 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -106,11 +106,8 @@ template <typename A> bool IsAssumedRank(const A *x) {
 bool IsCoarray(const ActualArgument &);
 bool IsCoarray(const Symbol &);
 template <typename A> bool IsCoarray(const A &) { return false; }
-template <typename A> bool IsCoarray(const Designator<A> &designator) {
-  if (const auto *symbol{std::get_if<SymbolRef>(&designator.u)}) {
-    return IsCoarray(**symbol);
-  }
-  return false;
+template <typename T> bool IsCoarray(const Designator<T> &designator) {
+  return designator.Corank() > 0;
 }
 template <typename T> bool IsCoarray(const Expr<T> &expr) {
   return common::visit([](const auto &x) { return IsCoarray(x); }, expr.u);
diff --git a/flang/include/flang/Evaluate/variable.h b/flang/include/flang/Evaluate/variable.h
index 9565826dbfaea4..178ce80cafded5 100644
--- a/flang/include/flang/Evaluate/variable.h
+++ b/flang/include/flang/Evaluate/variable.h
@@ -51,6 +51,7 @@ template <typename T> struct Variable;
 struct BaseObject {
   EVALUATE_UNION_CLASS_BOILERPLATE(BaseObject)
   int Rank() const;
+  int Corank() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
   llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
   const Symbol *symbol() const {
@@ -84,6 +85,7 @@ class Component {
   SymbolRef &symbol() { return symbol_; }
 
   int Rank() const;
+  int Corank() const;
   const Symbol &GetFirstSymbol() const;
   const Symbol &GetLastSymbol() const { return symbol_; }
   std::optional<Expr<SubscriptInteger>> LEN() const;
@@ -116,6 +118,7 @@ class NamedEntity {
   Component *UnwrapComponent();
 
   int Rank() const;
+  int Corank() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
   bool operator==(const NamedEntity &) const;
   llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
@@ -224,6 +227,7 @@ class ArrayRef {
   }
 
   int Rank() const;
+  int Corank() const;
   const Symbol &GetFirstSymbol() const;
   const Symbol &GetLastSymbol() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
@@ -271,6 +275,7 @@ class CoarrayRef {
   CoarrayRef &set_team(Expr<SomeInteger> &&, bool isTeamNumber = false);
 
   int Rank() const;
+  int Corank() const { return 0; }
   const Symbol &GetFirstSymbol() const;
   const Symbol &GetLastSymbol() const;
   NamedEntity GetBase() const;
@@ -294,6 +299,7 @@ class CoarrayRef {
 struct DataRef {
   EVALUATE_UNION_CLASS_BOILERPLATE(DataRef)
   int Rank() const;
+  int Corank() const;
   const Symbol &GetFirstSymbol() const;
   const Symbol &GetLastSymbol() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
@@ -331,6 +337,7 @@ class Substring {
   Parent &parent() { return parent_; }
 
   int Rank() const;
+  int Corank() const;
   template <typename A> const A *GetParentIf() const {
     return std::get_if<A>(&parent_);
   }
@@ -361,6 +368,7 @@ class ComplexPart {
   const DataRef &complex() const { return complex_; }
   Part part() const { return part_; }
   int Rank() const;
+  int Corank() const;
   const Symbol &GetFirstSymbol() const { return complex_.GetFirstSymbol(); }
   const Symbol &GetLastSymbol() const { return complex_.GetLastSymbol(); }
   bool operator==(const ComplexPart &) const;
@@ -396,6 +404,7 @@ template <typename T> class Designator {
 
   std::optional<DynamicType> GetType() const;
   int Rank() const;
+  int Corank() const;
   BaseObject GetBaseObject() const;
   const Symbol *GetLastSymbol() const;
   std::optional<Expr<SubscriptInteger>> LEN() const;
diff --git a/flang/lib/Evaluate/variable.cpp b/flang/lib/Evaluate/variable.cpp
index 707a2065ca30a7..841d0f71ed0e2f 100644
--- a/flang/lib/Evaluate/variable.cpp
+++ b/flang/lib/Evaluate/variable.cpp
@@ -465,6 +465,59 @@ template <typename T> int Designator<T>::Rank() const {
       u);
 }
 
+// Corank()
+int BaseObject::Corank() const {
+  return common::visit(common::visitors{
+                           [](SymbolRef symbol) { return symbol->Corank(); },
+                           [](const StaticDataObject::Pointer &) { return 0; },
+                       },
+      u);
+}
+
+int Component::Corank() const {
+  if (int corank{symbol_->Corank()}; corank > 0) {
+    return corank;
+  }
+  return base().Corank();
+}
+
+int NamedEntity::Corank() const {
+  return common::visit(common::visitors{
+                           [](const SymbolRef s) { return s->Corank(); },
+                           [](const Component &c) { return c.Corank(); },
+                       },
+      u_);
+}
+
+int ArrayRef::Corank() const { return base().Corank(); }
+
+int DataRef::Corank() const {
+  return common::visit(common::visitors{
+                           [](SymbolRef symbol) { return symbol->Corank(); },
+                           [](const auto &x) { return x.Corank(); },
+                       },
+      u);
+}
+
+int Substring::Corank() const {
+  return common::visit(
+      common::visitors{
+          [](const DataRef &dataRef) { return dataRef.Corank(); },
+          [](const StaticDataObject::Pointer &) { return 0; },
+      },
+      parent_);
+}
+
+int ComplexPart::Corank() const { return complex_.Corank(); }
+
+template <typename T> int Designator<T>::Corank() const {
+  return common::visit(common::visitors{
+                           [](SymbolRef symbol) { return symbol->Corank(); },
+                           [](const auto &x) { return x.Corank(); },
+                       },
+      u);
+}
+
 // GetBaseObject(), GetFirstSymbol(), GetLastSymbol(), &c.
 const Symbol &Component::GetFirstSymbol() const {
   return base_.value().GetFirstSymbol();
diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp
index 597c280a6df8bc..95df34b4a1f3e9 100644
--- a/flang/lib/Semantics/check-call.cpp
+++ b/flang/lib/Semantics/check-call.cpp
@@ -1622,8 +1622,8 @@ static void CheckImage_Index(evaluate::ActualArguments &arguments,
             evaluate::GetShape(arguments[1]->UnwrapExpr())}) {
       if (const auto *coarrayArgSymbol{UnwrapWholeSymbolOrComponentDataRef(
               arguments[0]->UnwrapExpr())}) {
-        const auto coarrayArgCorank = coarrayArgSymbol->Corank();
-        if (const auto subArrSize = evaluate::ToInt64(*subArrShape->front())) {
+        auto coarrayArgCorank{coarrayArgSymbol->Corank()};
+        if (auto subArrSize{evaluate::ToInt64(*subArrShape->front())}) {
           if (subArrSize != coarrayArgCorank) {
             messages.Say(arguments[1]->sourceLocation(),
                 "The size of 'SUB=' (%jd) for intrinsic 'image_index' must be equal to the corank of 'COARRAY=' (%d)"_err_en_US,
diff --git a/flang/test/Semantics/this_image01.f90 b/flang/test/Semantics/this_image01.f90
index 0e59aa3fa27c6b..efe39e2fee503c 100644
--- a/flang/test/Semantics/this_image01.f90
+++ b/flang/test/Semantics/this_image01.f90
@@ -17,6 +17,10 @@ subroutine test
   print *, this_image(coarray, team)
   print *, this_image(coarray, 1)
   print *, this_image(coarray, 1, team)
+  print *, this_image(coarray(1))
+  print *, this_image(coarray(1), team)
+  print *, this_image(coarray(1), 1)
+  print *, this_image(coarray(1), 1, team)
   print *, this_image(coscalar)
   print *, this_image(coscalar, team)
   print *, this_image(coscalar, 1)



More information about the flang-commits mailing list