[flang-commits] [flang] [flang] Make IsCoarray() more accurate (PR #121415)
via flang-commits
flang-commits at lists.llvm.org
Tue Dec 31 13:23:59 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
A designator without cosubscripts can have subscripts, component references, substrings, etc., and still have a nonzero corank. The current IsCoarray() predicate appears to work only for whole-variable and component references, which was breaking some uses of THIS_IMAGE().
---
Full diff: https://github.com/llvm/llvm-project/pull/121415.diff
5 Files Affected:
- (modified) flang/include/flang/Evaluate/tools.h (+2-5)
- (modified) flang/include/flang/Evaluate/variable.h (+9)
- (modified) flang/lib/Evaluate/variable.cpp (+53)
- (modified) flang/lib/Semantics/check-call.cpp (+2-2)
- (modified) flang/test/Semantics/this_image01.f90 (+4)
``````````diff
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index f586c59d46e54c..c6e8cf19b2d9f8 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -106,11 +106,8 @@ template <typename A> bool IsAssumedRank(const A *x) {
bool IsCoarray(const ActualArgument &);
bool IsCoarray(const Symbol &);
template <typename A> bool IsCoarray(const A &) { return false; }
-template <typename A> bool IsCoarray(const Designator<A> &designator) {
- if (const auto *symbol{std::get_if<SymbolRef>(&designator.u)}) {
- return IsCoarray(**symbol);
- }
- return false;
+template <typename T> bool IsCoarray(const Designator<T> &designator) {
+ return designator.Corank() > 0;
}
template <typename T> bool IsCoarray(const Expr<T> &expr) {
return common::visit([](const auto &x) { return IsCoarray(x); }, expr.u);
diff --git a/flang/include/flang/Evaluate/variable.h b/flang/include/flang/Evaluate/variable.h
index 9565826dbfaea4..178ce80cafded5 100644
--- a/flang/include/flang/Evaluate/variable.h
+++ b/flang/include/flang/Evaluate/variable.h
@@ -51,6 +51,7 @@ template <typename T> struct Variable;
struct BaseObject {
EVALUATE_UNION_CLASS_BOILERPLATE(BaseObject)
int Rank() const;
+ int Corank() const;
std::optional<Expr<SubscriptInteger>> LEN() const;
llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
const Symbol *symbol() const {
@@ -84,6 +85,7 @@ class Component {
SymbolRef &symbol() { return symbol_; }
int Rank() const;
+ int Corank() const;
const Symbol &GetFirstSymbol() const;
const Symbol &GetLastSymbol() const { return symbol_; }
std::optional<Expr<SubscriptInteger>> LEN() const;
@@ -116,6 +118,7 @@ class NamedEntity {
Component *UnwrapComponent();
int Rank() const;
+ int Corank() const;
std::optional<Expr<SubscriptInteger>> LEN() const;
bool operator==(const NamedEntity &) const;
llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
@@ -224,6 +227,7 @@ class ArrayRef {
}
int Rank() const;
+ int Corank() const;
const Symbol &GetFirstSymbol() const;
const Symbol &GetLastSymbol() const;
std::optional<Expr<SubscriptInteger>> LEN() const;
@@ -271,6 +275,7 @@ class CoarrayRef {
CoarrayRef &set_team(Expr<SomeInteger> &&, bool isTeamNumber = false);
int Rank() const;
+ int Corank() const { return 0; }
const Symbol &GetFirstSymbol() const;
const Symbol &GetLastSymbol() const;
NamedEntity GetBase() const;
@@ -294,6 +299,7 @@ class CoarrayRef {
struct DataRef {
EVALUATE_UNION_CLASS_BOILERPLATE(DataRef)
int Rank() const;
+ int Corank() const;
const Symbol &GetFirstSymbol() const;
const Symbol &GetLastSymbol() const;
std::optional<Expr<SubscriptInteger>> LEN() const;
@@ -331,6 +337,7 @@ class Substring {
Parent &parent() { return parent_; }
int Rank() const;
+ int Corank() const;
template <typename A> const A *GetParentIf() const {
return std::get_if<A>(&parent_);
}
@@ -361,6 +368,7 @@ class ComplexPart {
const DataRef &complex() const { return complex_; }
Part part() const { return part_; }
int Rank() const;
+ int Corank() const;
const Symbol &GetFirstSymbol() const { return complex_.GetFirstSymbol(); }
const Symbol &GetLastSymbol() const { return complex_.GetLastSymbol(); }
bool operator==(const ComplexPart &) const;
@@ -396,6 +404,7 @@ template <typename T> class Designator {
std::optional<DynamicType> GetType() const;
int Rank() const;
+ int Corank() const;
BaseObject GetBaseObject() const;
const Symbol *GetLastSymbol() const;
std::optional<Expr<SubscriptInteger>> LEN() const;
diff --git a/flang/lib/Evaluate/variable.cpp b/flang/lib/Evaluate/variable.cpp
index 707a2065ca30a7..841d0f71ed0e2f 100644
--- a/flang/lib/Evaluate/variable.cpp
+++ b/flang/lib/Evaluate/variable.cpp
@@ -465,6 +465,59 @@ template <typename T> int Designator<T>::Rank() const {
u);
}
+// Corank()
+int BaseObject::Corank() const {
+ return common::visit(common::visitors{
+ [](SymbolRef symbol) { return symbol->Corank(); },
+ [](const StaticDataObject::Pointer &) { return 0; },
+ },
+ u);
+}
+
+int Component::Corank() const {
+ if (int corank{symbol_->Corank()}; corank > 0) {
+ return corank;
+ }
+ return base().Corank();
+}
+
+int NamedEntity::Corank() const {
+ return common::visit(common::visitors{
+ [](const SymbolRef s) { return s->Corank(); },
+ [](const Component &c) { return c.Corank(); },
+ },
+ u_);
+}
+
+int ArrayRef::Corank() const { return base().Corank(); }
+
+int DataRef::Corank() const {
+ return common::visit(common::visitors{
+ [](SymbolRef symbol) { return symbol->Corank(); },
+ [](const auto &x) { return x.Corank(); },
+ },
+ u);
+}
+
+int Substring::Corank() const {
+ return common::visit(
+ common::visitors{
+ [](const DataRef &dataRef) { return dataRef.Corank(); },
+ [](const StaticDataObject::Pointer &) { return 0; },
+ },
+ parent_);
+}
+
+int ComplexPart::Corank() const { return complex_.Corank(); }
+
+template <typename T> int Designator<T>::Corank() const {
+ return common::visit(common::visitors{
+ [](SymbolRef symbol) { return symbol->Corank(); },
+ [](const auto &x) { return x.Corank(); },
+ },
+ u);
+}
+
// GetBaseObject(), GetFirstSymbol(), GetLastSymbol(), &c.
const Symbol &Component::GetFirstSymbol() const {
return base_.value().GetFirstSymbol();
diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp
index 597c280a6df8bc..95df34b4a1f3e9 100644
--- a/flang/lib/Semantics/check-call.cpp
+++ b/flang/lib/Semantics/check-call.cpp
@@ -1622,8 +1622,8 @@ static void CheckImage_Index(evaluate::ActualArguments &arguments,
evaluate::GetShape(arguments[1]->UnwrapExpr())}) {
if (const auto *coarrayArgSymbol{UnwrapWholeSymbolOrComponentDataRef(
arguments[0]->UnwrapExpr())}) {
- const auto coarrayArgCorank = coarrayArgSymbol->Corank();
- if (const auto subArrSize = evaluate::ToInt64(*subArrShape->front())) {
+ auto coarrayArgCorank{coarrayArgSymbol->Corank()};
+ if (auto subArrSize{evaluate::ToInt64(*subArrShape->front())}) {
if (subArrSize != coarrayArgCorank) {
messages.Say(arguments[1]->sourceLocation(),
"The size of 'SUB=' (%jd) for intrinsic 'image_index' must be equal to the corank of 'COARRAY=' (%d)"_err_en_US,
diff --git a/flang/test/Semantics/this_image01.f90 b/flang/test/Semantics/this_image01.f90
index 0e59aa3fa27c6b..efe39e2fee503c 100644
--- a/flang/test/Semantics/this_image01.f90
+++ b/flang/test/Semantics/this_image01.f90
@@ -17,6 +17,10 @@ subroutine test
print *, this_image(coarray, team)
print *, this_image(coarray, 1)
print *, this_image(coarray, 1, team)
+ print *, this_image(coarray(1))
+ print *, this_image(coarray(1), team)
+ print *, this_image(coarray(1), 1)
+ print *, this_image(coarray(1), 1, team)
print *, this_image(coscalar)
print *, this_image(coscalar, team)
print *, this_image(coscalar, 1)
``````````
</details>
https://github.com/llvm/llvm-project/pull/121415
More information about the flang-commits mailing list