[flang-commits] [flang] [flang] Fix SIZEOF() expression rewriting (PR #66241)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Mon Sep 18 09:20:05 PDT 2023


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/66241

From 1c1a4037ed879bb8d769a6c475388e420b726998 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Wed, 6 Sep 2023 15:41:55 -0700
Subject: [PATCH] [flang] Fix SIZEOF() expression rewriting

The rewriting of the extension intrinsic function SIZEOF
was producing results that would reference symbols that were
not available in the current scope, leading to crashes in lowering.
These symbols could be function result variables, for SIZEOF(func()),
or bare derived-type component names, for SIZEOF(array(n)%component).
Fixing this without regressing an existing test case required
careful threading of some state through the TypeAndShape characterization
code and the shape/bounds analyzer, and some cleanup was done
along the way.

Pull request: https://github.com/llvm/llvm-project/pull/66241
---
 .../include/flang/Evaluate/characteristics.h  |  8 +--
 flang/include/flang/Evaluate/shape.h          | 31 +++--------
 flang/lib/Evaluate/shape.cpp                  | 55 +++++++++----------
 flang/test/Evaluate/rewrite05.f90             | 34 ++++++++++++
 4 files changed, 72 insertions(+), 56 deletions(-)
 create mode 100644 flang/test/Evaluate/rewrite05.f90

diff --git a/flang/include/flang/Evaluate/characteristics.h b/flang/include/flang/Evaluate/characteristics.h
index 20750dfad8ce06e..bcb154320364097 100644
--- a/flang/include/flang/Evaluate/characteristics.h
+++ b/flang/include/flang/Evaluate/characteristics.h
@@ -88,11 +88,11 @@ class TypeAndShape {
   static std::optional<TypeAndShape> Characterize(
       const ActualArgument &, FoldingContext &, bool invariantOnly = false);
 
-  // General case for Expr<T>, ActualArgument, &c.
+  // General case for Expr<T>, &c.
   template <typename A>
   static std::optional<TypeAndShape> Characterize(
       const A &x, FoldingContext &context, bool invariantOnly = false) {
-    if (const auto *symbol{UnwrapWholeSymbolOrComponentDataRef(x)}) {
+    if (const auto *symbol{UnwrapWholeSymbolDataRef(x)}) {
       if (auto result{Characterize(*symbol, context, invariantOnly)}) {
         return result;
       }
@@ -116,7 +116,7 @@ class TypeAndShape {
   static std::optional<TypeAndShape> Characterize(
       const Designator<Type<TypeCategory::Character, KIND>> &x,
       FoldingContext &context, bool invariantOnly = true) {
-    if (const auto *symbol{UnwrapWholeSymbolOrComponentDataRef(x)}) {
+    if (const auto *symbol{UnwrapWholeSymbolDataRef(x)}) {
       if (auto result{Characterize(*symbol, context, invariantOnly)}) {
         return result;
       }
@@ -184,8 +184,6 @@ class TypeAndShape {
   static std::optional<TypeAndShape> Characterize(
       const semantics::AssocEntityDetails &, FoldingContext &,
       bool invariantOnly = true);
-  static std::optional<TypeAndShape> Characterize(
-      const semantics::ProcEntityDetails &, FoldingContext &);
   void AcquireAttrs(const semantics::Symbol &);
   void AcquireLEN();
   void AcquireLEN(const semantics::Symbol &);
diff --git a/flang/include/flang/Evaluate/shape.h b/flang/include/flang/Evaluate/shape.h
index 5acc7f13d27da58..1294c92a01abb60 100644
--- a/flang/include/flang/Evaluate/shape.h
+++ b/flang/include/flang/Evaluate/shape.h
@@ -131,14 +131,8 @@ class GetShapeHelper
   using Result = std::optional<Shape>;
   using Base = AnyTraverse<GetShapeHelper, Result>;
   using Base::operator();
-  explicit GetShapeHelper(bool invariantOnly)
-      : Base{*this}, invariantOnly_{invariantOnly} {}
-  explicit GetShapeHelper(FoldingContext &c, bool invariantOnly)
-      : Base{*this}, context_{&c}, invariantOnly_{invariantOnly} {}
-  explicit GetShapeHelper(
-      FoldingContext &c, bool useResultSymbolShape, bool invariantOnly)
-      : Base{*this}, context_{&c}, useResultSymbolShape_{useResultSymbolShape},
-        invariantOnly_{invariantOnly} {}
+  GetShapeHelper(FoldingContext *context, bool invariantOnly)
+      : Base{*this}, context_{context}, invariantOnly_{invariantOnly} {}
 
   Result operator()(const ImpliedDoIndex &) const { return ScalarShape(); }
   Result operator()(const DescriptorInquiry &) const { return ScalarShape(); }
@@ -187,9 +181,7 @@ class GetShapeHelper
     return common::visit(
         common::visitors{
             [&](const Expr<T> &x) -> MaybeExtentExpr {
-              if (auto xShape{!useResultSymbolShape_ ? (*this)(x)
-                          : context_                 ? GetShape(*context_, x)
-                                                     : GetShape(x)}) {
+              if (auto xShape{(*this)(x)}) {
                 // Array values in array constructors get linearized.
                 return GetSize(std::move(*xShape));
               } else {
@@ -233,7 +225,7 @@ class GetShapeHelper
   void AccumulateExtent(ExtentExpr &, ExtentExpr &&) const;
 
   FoldingContext *context_{nullptr};
-  bool useResultSymbolShape_{true};
+  mutable bool useResultSymbolShape_{true};
   // When invariantOnly=false, the returned shape need not be invariant
   // in its scope; in particular, it may contain references to dummy arguments.
   bool invariantOnly_{true};
@@ -242,7 +234,7 @@ class GetShapeHelper
 template <typename A>
 std::optional<Shape> GetShape(
     FoldingContext &context, const A &x, bool invariantOnly) {
-  if (auto shape{GetShapeHelper{context, invariantOnly}(x)}) {
+  if (auto shape{GetShapeHelper{&context, invariantOnly}(x)}) {
     return Fold(context, std::move(shape));
   } else {
     return std::nullopt;
@@ -251,17 +243,13 @@ std::optional<Shape> GetShape(
 
 template <typename A>
 std::optional<Shape> GetShape(const A &x, bool invariantOnly) {
-  return GetShapeHelper{invariantOnly}(x);
+  return GetShapeHelper{/*context=*/nullptr, invariantOnly}(x);
 }
 
 template <typename A>
 std::optional<Shape> GetShape(
     FoldingContext *context, const A &x, bool invariantOnly = true) {
-  if (context) {
-    return GetShape(*context, x, invariantOnly);
-  } else {
-    return GetShapeHelper{invariantOnly}(x);
-  }
+  return GetShapeHelper{context, invariantOnly}(x);
 }
 
 template <typename A>
@@ -286,12 +274,11 @@ std::optional<ConstantSubscripts> GetConstantExtents(
 
 // Get shape that does not depends on callee scope symbols if the expression
 // contains calls. Return std::nullopt if it is not possible to build such shape
-// (e.g. for calls to array functions whose result shape depends on the
+// (e.g. for calls to array-valued functions whose result shape depends on the
 // arguments).
 template <typename A>
 std::optional<Shape> GetContextFreeShape(FoldingContext &context, const A &x) {
-  return GetShapeHelper{
-      context, /*useResultSymbolShape=*/false, /*invariantOnly=*/true}(x);
+  return GetShapeHelper{&context, /*invariantOnly=*/true}(x);
 }
 
 // Compilation-time shape conformance checking, when corresponding extents
diff --git a/flang/lib/Evaluate/shape.cpp b/flang/lib/Evaluate/shape.cpp
index e26479cc1f055fb..c356af71021f3f4 100644
--- a/flang/lib/Evaluate/shape.cpp
+++ b/flang/lib/Evaluate/shape.cpp
@@ -323,10 +323,10 @@ class GetLowerBoundHelper
         if (IsActuallyConstant(exprLowerBound)) {
           return std::move(exprLowerBound);
         } else {
-          // If the lower bound of the associated entity is not resolved to
+          // If the lower bound of the associated entity is not resolved to a
           // constant expression at the time of the association, it is unsafe
           // to re-evaluate it later in the associate construct. Statements
-          // in-between may have modified its operands value.
+          // in between may have modified the values of its operands.
           return ExtentExpr{DescriptorInquiry{std::move(base),
               DescriptorInquiry::Field::LowerBound, dimension_}};
         }
@@ -476,24 +476,23 @@ static MaybeExtentExpr GetNonNegativeExtent(
   }
 }
 
-MaybeExtentExpr GetAssociatedExtent(const NamedEntity &base,
-    const semantics::AssocEntityDetails &assoc, int dimension) {
-  if (auto shape{GetShape(assoc.expr())}) {
-    if (dimension < static_cast<int>(shape->size())) {
-      auto &extent{shape->at(dimension)};
-      if (extent && IsActuallyConstant(*extent)) {
+static MaybeExtentExpr GetAssociatedExtent(
+    const Symbol &symbol, int dimension) {
+  if (const auto *assoc{symbol.detailsIf<semantics::AssocEntityDetails>()};
+      assoc && !assoc->rank()) { // not SELECT RANK case
+    if (auto shape{GetShape(assoc->expr())};
+        shape && dimension < static_cast<int>(shape->size())) {
+      if (auto &extent{shape->at(dimension)};
+          // Don't return a non-constant extent, as the variables that
+          // determine the shape of the selector's expression may change
+          // during execution of the construct.
+          extent && IsActuallyConstant(*extent)) {
         return std::move(extent);
-      } else {
-        // Otherwise, evaluating the associated expression extent expression
-        // after the associate statement is unsafe given statements inside the
-        // associate may have modified the associated expression operands
-        // values.
-        return ExtentExpr{DescriptorInquiry{
-            NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
       }
     }
   }
-  return std::nullopt;
+  return ExtentExpr{DescriptorInquiry{
+      NamedEntity{symbol}, DescriptorInquiry::Field::Extent, dimension}};
 }
 
 MaybeExtentExpr GetExtent(
@@ -508,14 +507,16 @@ MaybeExtentExpr GetExtent(
       if (semantics::IsDescriptor(symbol) && dimension < *assoc->rank()) {
         return ExtentExpr{DescriptorInquiry{
             NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
+      } else {
+        return std::nullopt;
       }
     } else {
-      return GetAssociatedExtent(base, *assoc, dimension);
+      return GetAssociatedExtent(last, dimension);
     }
   }
   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
     if (IsImpliedShape(symbol) && details->init()) {
-      if (auto shape{GetShape(symbol)}) {
+      if (auto shape{GetShape(symbol, invariantOnly)}) {
         if (dimension < static_cast<int>(shape->size())) {
           return std::move(shape->at(dimension));
         }
@@ -527,7 +528,7 @@ MaybeExtentExpr GetExtent(
           if (auto extent{GetNonNegativeExtent(shapeSpec, invariantOnly)}) {
             return extent;
           } else if (details->IsAssumedSize() && j == symbol.Rank()) {
-            return std::nullopt;
+            break;
           } else if (semantics::IsDescriptor(symbol)) {
             return ExtentExpr{DescriptorInquiry{NamedEntity{base},
                 DescriptorInquiry::Field::Extent, dimension}};
@@ -620,7 +621,7 @@ MaybeExtentExpr GetRawUpperBound(
       return std::nullopt;
     } else if (assoc->rank() && dimension >= *assoc->rank()) {
       return std::nullopt;
-    } else if (auto extent{GetAssociatedExtent(base, *assoc, dimension)}) {
+    } else if (auto extent{GetAssociatedExtent(symbol, dimension)}) {
       return ComputeUpperBound(
           GetRawLowerBound(base, dimension), std::move(extent));
     }
@@ -680,11 +681,9 @@ static MaybeExtentExpr GetUBOUND(FoldingContext *context,
             std::move(base), DescriptorInquiry::Field::Extent, dimension}};
         return ComputeUpperBound(std::move(lb), std::move(extent));
       }
-    } else if (assoc->expr()) {
-      if (auto extent{GetAssociatedExtent(base, *assoc, dimension)}) {
-        if (auto lb{GetLBOUND(base, dimension, invariantOnly)}) {
-          return ComputeUpperBound(std::move(*lb), std::move(extent));
-        }
+    } else if (auto extent{GetAssociatedExtent(symbol, dimension)}) {
+      if (auto lb{GetLBOUND(base, dimension, invariantOnly)}) {
+        return ComputeUpperBound(std::move(*lb), std::move(extent));
       }
     }
   }
@@ -768,7 +767,7 @@ auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
               auto resultShape{(*this)(subp.result())};
               if (resultShape && !useResultSymbolShape_) {
                 // Ensure the shape is constant. Otherwise, it may be referring
-                // to symbols that belong to the subroutine scope and are
+                // to symbols that belong to the function's scope and are
                 // meaningless on the caller side without the related call
                 // expression.
                 for (auto &extent : *resultShape) {
@@ -799,9 +798,6 @@ auto GetShapeHelper::operator()(const Component &component) const -> Result {
   } else if (symbol.has<semantics::ObjectEntityDetails>()) {
     NamedEntity base{Component{component}};
     return CreateShape(rank, base);
-  } else if (symbol.has<semantics::AssocEntityDetails>()) {
-    NamedEntity base{Component{component}};
-    return Result{CreateShape(rank, base)};
   } else {
     return (*this)(symbol);
   }
@@ -878,6 +874,7 @@ auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
     }
     return ScalarShape();
   } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
+    auto restorer{common::ScopedSet(useResultSymbolShape_, false)};
     return (*this)(*symbol);
   } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
     if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
diff --git a/flang/test/Evaluate/rewrite05.f90 b/flang/test/Evaluate/rewrite05.f90
new file mode 100644
index 000000000000000..f81974f24fd9717
--- /dev/null
+++ b/flang/test/Evaluate/rewrite05.f90
@@ -0,0 +1,34 @@
+! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
+program main
+  type t
+    integer, allocatable :: component(:)
+  end type
+  type(t) :: x
+  call init(10)
+  !CHECK: PRINT *, [INTEGER(4)::int(lbound(x%component,dim=1,kind=8),kind=4)]
+  print *, lbound(x%component)
+  !CHECK: PRINT *, [INTEGER(4)::int(size(x%component,dim=1,kind=8)+lbound(x%component,dim=1,kind=8)-1_8,kind=4)]
+  print *, ubound(x%component)
+  !CHECK: PRINT *, int(size(x%component,dim=1,kind=8),kind=4)
+  print *, size(x%component)
+  !CHECK: PRINT *, 4_8*size(x%component,dim=1,kind=8)
+  print *, sizeof(x%component)
+  !CHECK: PRINT *, 1_4
+  print *, lbound(iota(10), 1)
+  !CHECK: PRINT *, ubound(iota(10_4),1_4)
+  print *, ubound(iota(10), 1)
+  !CHECK: PRINT *, size(iota(10_4))
+  print *, size(iota(10))
+  !CHECK: PRINT *, sizeof(iota(10_4))
+  print *, sizeof(iota(10))
+ contains
+  function iota(n) result(result)
+    integer, intent(in) :: n
+    integer, allocatable :: result(:)
+    result = [(j,j=1,n)]
+  end
+  subroutine init(n)
+    integer, intent(in) :: n
+    allocate(x%component(0:n-1))
+  end
+end



More information about the flang-commits mailing list