[flang-commits] [flang] c4a7395 - [flang] Limit shape inquiries rewrite to associate construct entity

Jean Perier via flang-commits flang-commits at lists.llvm.org
Wed Sep 21 01:28:07 PDT 2022


Author: Jean Perier
Date: 2022-09-21T10:25:06+02:00
New Revision: c4a73957f6c59f4f88d186c1a98327a826fbfb16

URL: https://github.com/llvm/llvm-project/commit/c4a73957f6c59f4f88d186c1a98327a826fbfb16
DIFF: https://github.com/llvm/llvm-project/commit/c4a73957f6c59f4f88d186c1a98327a826fbfb16.diff

LOG: [flang] Limit shape inquiries rewrite to associate construct entity

The previous code was rewriting all shape inquiries on associate
construct entities to inquiries on the associated expression or variable.

This is incorrect because, at the point of inquiry, some statement
between the association and the inquiry may have modified the expression's
operands or the variable in a way that changes its shape or bounds.

For instance, in the example below, expression rewriting previously
replaced `size(x, 1)` with `size(p, 1)`, which is invalid if p is a
pointer, since the intervening call may change its shape.

```
associate(x => p + 1)
 call call_that_may_modify_p_shape()
 print *, size(x, 1)
end associate
```

This change restricts rewrites of shape inquiries on associate construct entities
so that the associated expression's shape and bounds are used if and only if
they are compile-time constants. Otherwise, such a rewrite may be invalid, so
the inquiry is performed on the associate construct entity itself.

Differential Revision: https://reviews.llvm.org/D133857

Added: 
    

Modified: 
    flang/include/flang/Evaluate/check-expression.h
    flang/lib/Evaluate/check-expression.cpp
    flang/lib/Evaluate/shape.cpp
    flang/test/Evaluate/rewrite01.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Evaluate/check-expression.h b/flang/include/flang/Evaluate/check-expression.h
index a53c8dfb57e41..810a8c9c7f7dc 100644
--- a/flang/include/flang/Evaluate/check-expression.h
+++ b/flang/include/flang/Evaluate/check-expression.h
@@ -50,6 +50,8 @@ template <typename A> bool IsActuallyConstant(const A &);
 extern template bool IsActuallyConstant(const Expr<SomeType> &);
 extern template bool IsActuallyConstant(const Expr<SomeInteger> &);
 extern template bool IsActuallyConstant(const Expr<SubscriptInteger> &);
+extern template bool IsActuallyConstant(
+    const std::optional<Expr<SubscriptInteger>> &);
 
 // Checks whether an expression is an object designator with
 // constant addressing and no vector-valued subscript.

diff  --git a/flang/lib/Evaluate/check-expression.cpp b/flang/lib/Evaluate/check-expression.cpp
index f2486ef9666a8..814785c79365a 100644
--- a/flang/lib/Evaluate/check-expression.cpp
+++ b/flang/lib/Evaluate/check-expression.cpp
@@ -190,6 +190,7 @@ template <typename A> bool IsActuallyConstant(const A &x) {
 template bool IsActuallyConstant(const Expr<SomeType> &);
 template bool IsActuallyConstant(const Expr<SomeInteger> &);
 template bool IsActuallyConstant(const Expr<SubscriptInteger> &);
+template bool IsActuallyConstant(const std::optional<Expr<SubscriptInteger>> &);
 
 // Object pointer initialization checking predicate IsInitialDataTarget().
 // This code determines whether an expression is allowable as the static

diff  --git a/flang/lib/Evaluate/shape.cpp b/flang/lib/Evaluate/shape.cpp
index a2748feaabd15..72c84cba86b13 100644
--- a/flang/lib/Evaluate/shape.cpp
+++ b/flang/lib/Evaluate/shape.cpp
@@ -314,7 +314,17 @@ class GetLowerBoundHelper
               DescriptorInquiry::Field::LowerBound, dimension_}};
         }
       } else {
-        return (*this)(assoc->expr());
+        auto exprLowerBound{((*this)(assoc->expr()))};
+        if (IsActuallyConstant(exprLowerBound)) {
+          return std::move(exprLowerBound);
+        } else {
+          // If the lower bound of the associated entity is not resolved to
+          // constant expression at the time of the association, it is unsafe
+          // to re-evaluate it later in the associate construct. Statements
+          // in-between may have modified its operands value.
+          return ExtentExpr{DescriptorInquiry{std::move(base),
+              DescriptorInquiry::Field::LowerBound, dimension_}};
+        }
       }
     }
     if constexpr (LBOUND_SEMANTICS) {
@@ -429,6 +439,26 @@ static MaybeExtentExpr GetNonNegativeExtent(
   }
 }
 
+MaybeExtentExpr GetAssociatedExtent(const NamedEntity &base,
+    const semantics::AssocEntityDetails &assoc, int dimension) {
+  if (auto shape{GetShape(assoc.expr())}) {
+    if (dimension < static_cast<int>(shape->size())) {
+      auto &extent{shape->at(dimension)};
+      if (extent && IsActuallyConstant(*extent)) { // Constant: safe to reuse.
+        return std::move(extent);
+      } else {
+        // The extent is unknown or not a compile-time constant. Re-evaluating
+        // the associated expression's extent inside the associate construct is
+        // unsafe: statements in between may have modified the values of its
+        // operands. Fall back to a run-time descriptor inquiry on `base`.
+        return ExtentExpr{DescriptorInquiry{
+            NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
+      }
+    }
+  }
+  return std::nullopt; // No known shape, or dimension >= rank.
+}
+
 MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) {
   CHECK(dimension >= 0);
   const Symbol &last{base.GetLastSymbol()};
@@ -439,10 +469,8 @@ MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) {
         return ExtentExpr{DescriptorInquiry{
             NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
       }
-    } else if (auto shape{GetShape(assoc->expr())}) {
-      if (dimension < static_cast<int>(shape->size())) {
-        return std::move(shape->at(dimension));
-      }
+    } else {
+      return GetAssociatedExtent(base, *assoc, dimension);
     }
   }
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
@@ -547,11 +575,9 @@ MaybeExtentExpr GetRawUpperBound(const NamedEntity &base, int dimension) {
     }
   } else if (const auto *assoc{
                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
-    if (auto shape{GetShape(assoc->expr())}) {
-      if (dimension < static_cast<int>(shape->size())) {
-        return ComputeUpperBound(
-            GetRawLowerBound(base, dimension), std::move(shape->at(dimension)));
-      }
+    if (auto extent{GetAssociatedExtent(base, *assoc, dimension)}) {
+      return ComputeUpperBound(
+          GetRawLowerBound(base, dimension), std::move(extent));
     }
   }
   return std::nullopt;
@@ -597,12 +623,9 @@ static MaybeExtentExpr GetUBOUND(
     }
   } else if (const auto *assoc{
                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
-    if (auto shape{GetShape(assoc->expr())}) {
-      if (dimension < static_cast<int>(shape->size())) {
-        if (auto lb{GetLBOUND(base, dimension)}) {
-          return ComputeUpperBound(
-              std::move(*lb), std::move(shape->at(dimension)));
-        }
+    if (auto extent{GetAssociatedExtent(base, *assoc, dimension)}) {
+      if (auto lb{GetLBOUND(base, dimension)}) {
+        return ComputeUpperBound(std::move(*lb), std::move(extent));
       }
     }
   }
@@ -674,12 +697,22 @@ auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
             }
           },
           [&](const semantics::AssocEntityDetails &assoc) {
+            NamedEntity base{symbol};
             if (assoc.rank()) { // SELECT RANK case
               int n{assoc.rank().value()};
-              NamedEntity base{symbol};
               return Result{CreateShape(n, base)};
             } else {
-              return (*this)(assoc.expr());
+              auto exprShape{((*this)(assoc.expr()))};
+              if (exprShape) {
+                int rank{static_cast<int>(exprShape->size())};
+                for (int dimension{0}; dimension < rank; ++dimension) {
+                  auto &extent{(*exprShape)[dimension]};
+                  if (extent && !IsActuallyConstant(*extent)) {
+                    extent = GetExtent(base, dimension);
+                  }
+                }
+              }
+              return exprShape;
             }
           },
           [&](const semantics::SubprogramDetails &subp) -> Result {

diff  --git a/flang/test/Evaluate/rewrite01.f90 b/flang/test/Evaluate/rewrite01.f90
index 0d781094b3d56..37fd11e7e8e8a 100644
--- a/flang/test/Evaluate/rewrite01.f90
+++ b/flang/test/Evaluate/rewrite01.f90
@@ -157,4 +157,33 @@ end function mofun
   print *, len(mofun(m+1))
 end subroutine len_test
 
+!CHECK-LABEL: associate_tests
+subroutine associate_tests(p) ! shape inquiries on associate-names
+  real, pointer :: p(:)
+  real :: a(10:20)
+  interface
+    subroutine may_change_p_bounds(p)
+      real, pointer :: p(:)
+    end subroutine
+  end interface
+  associate(x => p) ! pointer selector: bounds are not constant
+    call may_change_p_bounds(p)
+    !CHECK: PRINT *, lbound(x,dim=1,kind=8), size(x,dim=1,kind=8)+lbound(x,dim=1,kind=8)-1_8, size(x,dim=1,kind=8)
+    print *, lbound(x, 1, kind=8), ubound(x, 1, kind=8), size(x, 1, kind=8)
+  end associate
+  associate(x => p+1) ! expression selector: lower bound 1, extent not constant
+    call may_change_p_bounds(p)
+    !CHECK: PRINT *, 1_8, size(x,dim=1,kind=8), size(x,dim=1,kind=8)
+    print *, lbound(x, 1, kind=8), ubound(x, 1, kind=8), size(x, 1, kind=8)
+  end associate
+  associate(x => a) ! variable selector with constant bounds: folded
+    !CHECK: PRINT *, 10_8, 20_8, 11_8
+    print *, lbound(x, 1, kind=8), ubound(x, 1, kind=8), size(x, 1, kind=8)
+  end associate
+  associate(x => a+42.) ! constant-shape expression: lower bound 1, folded
+    !CHECK: PRINT *, 1_8, 11_8, 11_8
+    print *, lbound(x, 1, kind=8), ubound(x, 1, kind=8), size(x, 1, kind=8)
+  end associate
+end subroutine
+
 end module


        


More information about the flang-commits mailing list