[flang-commits] [flang] b34f116 - [flang] Fix assert on constant folding of extended types

Peter Steinfeld via flang-commits flang-commits at lists.llvm.org
Thu Sep 10 14:38:43 PDT 2020


Author: Peter Steinfeld
Date: 2020-09-10T14:34:03-07:00
New Revision: b34f116856306d97aa9244a46eb1643a8ddd49a8

URL: https://github.com/llvm/llvm-project/commit/b34f116856306d97aa9244a46eb1643a8ddd49a8
DIFF: https://github.com/llvm/llvm-project/commit/b34f116856306d97aa9244a46eb1643a8ddd49a8.diff

LOG: [flang] Fix assert on constant folding of extended types

When we define a derived type that extends another derived type, we can then
create a structure constructor that contains values for the fields of both the
child type and its parent.  The compiler's internal representation of that
value contains the name of the parent type where a component name would
normally appear.  This caused an assertion failure during constant folding.

There are three cases for components that appear in structure constructors.
The first is the normal case of a component appearing in a structure
constructor for its type.

The second is a component of the parent (or grandparent) type appearing in a
structure constructor for the child type.

The third is the parent type component, which can appear in the structure
constructor of its child.

There are also cases where the components are arrays.

I created the test case folding12.f90 that covers all of these cases and
modified the code to handle them.

Most of my changes were to the "Find()" method of the type
"StructureConstructor" where I added code to cover the second and third cases
described above.  To handle these cases, I needed to create a
"StructureConstructor" for the parent type component and return it.  To handle
returning a newly created "StructureConstructor", I changed the return type of
"Find()" to be "std::optional" rather than an ordinary pointer.

This change supersedes D86172.

Differential Revision: https://reviews.llvm.org/D87151

Added: 
    flang/test/Evaluate/folding12.f90

Modified: 
    flang/include/flang/Evaluate/expression.h
    flang/include/flang/Evaluate/type.h
    flang/lib/Evaluate/expression.cpp
    flang/lib/Evaluate/fold-implementation.h
    flang/lib/Evaluate/type.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Evaluate/expression.h b/flang/include/flang/Evaluate/expression.h
index 09847ec95407..f0ce375da015 100644
--- a/flang/include/flang/Evaluate/expression.h
+++ b/flang/include/flang/Evaluate/expression.h
@@ -717,7 +717,8 @@ class StructureConstructor {
     return values_.end();
   }
 
-  const Expr<SomeType> *Find(const Symbol &) const; // can return null
+  // Returns the component's value by value; std::nullopt if none is found
+  std::optional<Expr<SomeType>> Find(const Symbol &) const;
 
   StructureConstructor &Add(const semantics::Symbol &, Expr<SomeType> &&);
   int Rank() const { return 0; }
@@ -725,6 +726,7 @@ class StructureConstructor {
   llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
 
 private:
+  std::optional<Expr<SomeType>> CreateParentComponent(const Symbol &) const;
   Result result_;
   StructureConstructorValues values_;
 };

diff  --git a/flang/include/flang/Evaluate/type.h b/flang/include/flang/Evaluate/type.h
index cf13ba6e27d9..663ece6eb4a0 100644
--- a/flang/include/flang/Evaluate/type.h
+++ b/flang/include/flang/Evaluate/type.h
@@ -217,6 +217,8 @@ class DynamicType {
 const semantics::DerivedTypeSpec *GetDerivedTypeSpec(const DynamicType &);
 const semantics::DerivedTypeSpec *GetDerivedTypeSpec(
     const std::optional<DynamicType> &);
+const semantics::DerivedTypeSpec *GetParentTypeSpec(
+    const semantics::DerivedTypeSpec &);
 
 std::string DerivedTypeSpecAsFortran(const semantics::DerivedTypeSpec &);
 

diff  --git a/flang/lib/Evaluate/expression.cpp b/flang/lib/Evaluate/expression.cpp
index 5a456648b825..7f8c9eb32f3f 100644
--- a/flang/lib/Evaluate/expression.cpp
+++ b/flang/lib/Evaluate/expression.cpp
@@ -12,7 +12,12 @@
 #include "flang/Evaluate/common.h"
 #include "flang/Evaluate/tools.h"
 #include "flang/Evaluate/variable.h"
+#include "flang/Parser/char-block.h"
 #include "flang/Parser/message.h"
+#include "flang/Semantics/scope.h"
+#include "flang/Semantics/symbol.h"
+#include "flang/Semantics/tools.h"
+#include "flang/Semantics/type.h"
 #include "llvm/Support/raw_ostream.h"
 #include <string>
 #include <type_traits>
@@ -206,13 +211,75 @@ bool Expr<SomeType>::operator==(const Expr<SomeType> &that) const {
 
 DynamicType StructureConstructor::GetType() const { return result_.GetType(); }
 
-const Expr<SomeType> *StructureConstructor::Find(
+std::optional<Expr<SomeType>> StructureConstructor::CreateParentComponent(
+    const Symbol &component) const { // component: a derived type symbol
+  // Build a parent-type value from this constructor's leading values
+  if (const semantics::DerivedTypeSpec *
+      parentSpec{GetParentTypeSpec(derivedTypeSpec())}) {
+    StructureConstructor structureConstructor{*parentSpec};
+    if (const auto *parentDetails{
+            component.detailsIf<semantics::DerivedTypeDetails>()}) {
+      // NOTE(review): assumes the parent's values come first in values_
+      auto parentIter{parentDetails->componentNames().begin()};
+      for (const auto &childIter : values_) {
+        if (parentIter == parentDetails->componentNames().end()) {
+          break; // There are more components in the child
+        }
+        structureConstructor.Add(
+            *childIter.first, common::Clone(childIter.second.value()));
+        ++parentIter;
+      }
+      return std::optional<Expr<SomeType>>{Expr<SomeDerived>{
+          Constant<SomeDerived>{std::move(structureConstructor)}}};
+    }
+  }
+  return std::nullopt; // no parent type, or component is not a type
+}
+
+static const Symbol *GetParentComponentSymbol(const Symbol &symbol) {
+  if (symbol.test(Symbol::Flag::ParentComp)) {
+    // A parent component object: return the symbol of its derived type
+    const auto &compObject{symbol.get<semantics::ObjectEntityDetails>()};
+    if (const semantics::DeclTypeSpec * compType{compObject.type()}) {
+      const semantics::DerivedTypeSpec &dtSpec{compType->derivedTypeSpec()};
+      const semantics::Symbol &compTypeSymbol{dtSpec.typeSymbol()};
+      return &compTypeSymbol;
+    }
+  }
+  if (symbol.detailsIf<semantics::DerivedTypeDetails>()) {
+    // A derived type symbol in component position names the parent component
+    return &symbol;
+  }
+  return nullptr; // not a reference to a parent component
+}
+
+std::optional<Expr<SomeType>> StructureConstructor::Find(
     const Symbol &component) const {
   if (auto iter{values_.find(component)}; iter != values_.end()) {
-    return &iter->second.value();
-  } else {
-    return nullptr;
+    return iter->second.value(); // a copy of the stored value
+  }
+  // No value stored directly: if the component denotes the parent component
+  // of an extended type, build a parent-type value from values_
+  if (const Symbol * typeSymbol{GetParentComponentSymbol(component)}) {
+    return CreateParentComponent(*typeSymbol);
+  }
+  // Otherwise search the parent component's value, expected to be the first
+  // entry of values_; only a scalar derived-type constant is searched
+  if (!values_.empty()) {
+    const Expr<SomeType> *parentExpr{&values_.begin()->second.value()};
+    if (const Expr<SomeDerived> *derivedExpr{
+            std::get_if<Expr<SomeDerived>>(&parentExpr->u)}) {
+      if (const Constant<SomeDerived> *constExpr{
+              std::get_if<Constant<SomeDerived>>(&derivedExpr->u)}) {
+        if (std::optional<StructureConstructor> parentComponentValue{
+                constExpr->GetScalarValue()}) {
+          // Recursively search the parent's structure constructor
+          return parentComponentValue->Find(component);
+        }
+      }
+    }
   }
+  return std::nullopt; // no value for the component was found
 }
 
 StructureConstructor &StructureConstructor::Add(

diff  --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index e01c7de72f8d..bb5463e697fe 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -296,8 +296,8 @@ std::optional<Constant<T>> Folder<T>::ApplyComponent(
     Constant<SomeDerived> &&structures, const Symbol &component,
     const std::vector<Constant<SubscriptInteger>> *subscripts) {
   if (auto scalar{structures.GetScalarValue()}) {
-    if (auto *expr{scalar->Find(component)}) {
-      if (const Constant<T> *value{UnwrapConstantValue<T>(*expr)}) {
+    if (std::optional<Expr<SomeType>> expr{scalar->Find(component)}) {
+      if (const Constant<T> *value{UnwrapConstantValue<T>(expr.value())}) {
         if (!subscripts) {
           return std::move(*value);
         } else {
@@ -314,12 +314,12 @@ std::optional<Constant<T>> Folder<T>::ApplyComponent(
     ConstantSubscripts at{structures.lbounds()};
     do {
       StructureConstructor scalar{structures.At(at)};
-      if (auto *expr{scalar.Find(component)}) {
-        if (const Constant<T> *value{UnwrapConstantValue<T>(*expr)}) {
+      if (std::optional<Expr<SomeType>> expr{scalar.Find(component)}) {
+        if (const Constant<T> *value{UnwrapConstantValue<T>(expr.value())}) {
           if (!array.get()) {
             // This technique ensures that character length or derived type
             // information is propagated to the array constructor.
-            auto *typedExpr{UnwrapExpr<Expr<T>>(*expr)};
+            auto *typedExpr{UnwrapExpr<Expr<T>>(expr.value())};
             CHECK(typedExpr);
             array = std::make_unique<ArrayConstructor<T>>(*typedExpr);
           }

diff  --git a/flang/lib/Evaluate/type.cpp b/flang/lib/Evaluate/type.cpp
index e1eec19e896b..e96e19150f4e 100644
--- a/flang/lib/Evaluate/type.cpp
+++ b/flang/lib/Evaluate/type.cpp
@@ -207,7 +207,7 @@ static const semantics::Symbol *FindParentComponent(
   return nullptr;
 }
 
-static const semantics::DerivedTypeSpec *GetParentTypeSpec(
+const semantics::DerivedTypeSpec *GetParentTypeSpec(
     const semantics::DerivedTypeSpec &derived) {
   if (const semantics::Symbol * parent{FindParentComponent(derived)}) {
     return &parent->get<semantics::ObjectEntityDetails>()

diff  --git a/flang/test/Evaluate/folding12.f90 b/flang/test/Evaluate/folding12.f90
new file mode 100644
index 000000000000..657ddc6a34ae
--- /dev/null
+++ b/flang/test/Evaluate/folding12.f90
@@ -0,0 +1,163 @@
+! RUN: %S/test_folding.sh %s %t %f18
+! Test folding of structure constructors
+module m1
+  type parent_type
+    integer :: parent_field
+  end type parent_type
+  type, extends(parent_type) :: child_type
+    integer :: child_field 
+  end type child_type
+  type parent_array_type
+    integer, dimension(2) :: parent_field
+  end type parent_array_type
+  type, extends(parent_array_type) :: child_array_type
+    integer :: child_field
+  end type child_array_type
+  ! Child field and inherited parent field of a scalar child constant
+  type(child_type), parameter :: child_const1 = child_type(10, 11)
+  logical, parameter :: test_child1 = child_const1%child_field == 11
+  logical, parameter :: test_parent = child_const1%parent_field == 10
+  ! The same two fields, via an element of a constant array of child values
+  type(child_type), parameter :: child_const2 = child_type(12, 13)
+  type(child_type), parameter :: array_var(2) = &
+    [child_type(14, 15), child_type(16, 17)]
+  logical, parameter :: test_array_child = array_var(2)%child_field == 17 
+  logical, parameter :: test_array_parent = array_var(2)%parent_field == 16
+  ! Array-valued component initialized by an implied-DO array constructor
+  type array_type
+    real, dimension(3) :: real_field
+  end type array_type
+  type(array_type), parameter :: array_var2 = &
+    array_type([(real(i*i), i = 1,3)])
+  logical, parameter :: test_array_var = array_var2%real_field(2) == 4.0
+  ! Inherited parent field taken across a section of a child array
+  type(child_type), parameter, dimension(2) :: child_const3 = &
+    [child_type(18, 19), child_type(20, 21)]
+  integer, dimension(2), parameter :: int_const4 = &
+    child_const3(:)%parent_field
+  logical, parameter :: test_child2 = int_const4(1) == 18
+  ! Element of an inherited array-valued field, across a child array
+  type(child_array_type), parameter, dimension(2) :: child_const5 = &
+    [child_array_type([22, 23], 24), child_array_type([25, 26], 27)]
+  integer, dimension(2), parameter :: int_const6 = child_const5(:)%parent_field(2)
+  logical, parameter :: test_child3 = int_const6(1) == 23 
+  ! The parent component itself (child%parent_type) used as a value
+  type(child_type), parameter :: child_const7 =  child_type(28, 29)
+  type(parent_type), parameter :: parent_const8 = child_const7%parent_type
+  logical, parameter :: test_child4 = parent_const8%parent_field == 28
+  ! Child constructed from an explicit parent_type value
+  type(child_type), parameter :: child_const9 = &
+    child_type(parent_type(30), 31)
+  integer, parameter :: int_const10 = child_const9%parent_field
+  logical, parameter :: test_child5 = int_const10 == 30
+
+end module m1
+
+module m2
+  type grandparent_type
+    real :: grandparent_field
+  end type grandparent_type
+  type, extends(grandparent_type) :: parent_type
+    integer :: parent_field
+  end type parent_type
+  type, extends(parent_type) :: child_type
+    real :: child_field
+  end type child_type
+  ! Flat values; results stay real (like grandparent_field) to avoid truncation
+  type(child_type), parameter :: child_const1 = child_type(10.0, 11, 12.0)
+  real, parameter :: real_const2 = &
+    child_const1%grandparent_type%grandparent_field
+  logical, parameter :: test_child1 = real_const2 == 10.0
+  real, parameter :: real_const3 = &
+    child_const1%grandparent_field
+  logical, parameter :: test_child2 = real_const3 == 10.0
+  ! Explicit parent_type value in the child's structure constructor
+  type(child_type), parameter :: child_const4 = &
+    child_type(parent_type(13.0, 14), 15.0)
+  real, parameter :: real_const5 = &
+    child_const4%grandparent_type%grandparent_field
+  logical, parameter :: test_child3 = real_const5 == 13.0
+  ! Nested explicit parent_type and grandparent_type values
+  type(child_type), parameter :: child_const6 = &
+    child_type(parent_type(grandparent_type(16.0), 17), 18.0)
+  real, parameter :: real_const7 = &
+    child_const6%grandparent_type%grandparent_field
+  logical, parameter :: test_child4 = real_const7 == 16.0
+  real, parameter :: real_const8 = &
+    child_const6%grandparent_field
+  logical, parameter :: test_child5 = real_const8 == 16.0
+end module m2
+
+module m3
+  ! Tests that use components with default initializations and with the
+  ! components in the structure constructors in a different order from the
+  ! declared order
+  type parent_type
+    integer :: parent_field1
+    real :: parent_field2 = 20.0
+    logical :: parent_field3
+  end type parent_type
+  type, extends(parent_type) :: child_type
+    real :: child_field1
+    logical :: child_field2 = .false.
+    integer :: child_field3
+  end type child_type
+  ! Every component given by keyword, parent and child names interleaved
+  type(child_type), parameter :: child_const1 = &
+    child_type( &
+      parent_field2 = 10.0, child_field3 = 11, &
+      child_field2 = .true., parent_field3 = .false., &
+      parent_field1 = 12, child_field1 = 13.3)
+  logical, parameter :: test_child1 = child_const1%child_field1 == 13.3
+  logical, parameter :: test_child2 = child_const1%child_field2 .eqv. .true.
+  logical, parameter :: test_child3 = child_const1%child_field3 == 11
+  logical, parameter :: test_parent1 = child_const1%parent_field1 == 12
+  logical, parameter :: test_parent2 = child_const1%parent_field2 == 10.0
+  logical, parameter :: test_parent3 = child_const1%parent_field3 .eqv. .false.
+  logical, parameter :: test_parent4 = &
+    child_const1%parent_type%parent_field1 == 12
+  logical, parameter :: test_parent5 = &
+    child_const1%parent_type%parent_field2 == 10.0
+  logical, parameter :: test_parent6 = &
+    child_const1%parent_type%parent_field3 .eqv. .false.
+  ! The parent component of child_const1 as a parent_type constant
+  type(parent_type), parameter ::parent_const1 = child_const1%parent_type
+  logical, parameter :: test_parent7 = parent_const1%parent_field1 == 12
+  logical, parameter :: test_parent8 = parent_const1%parent_field2 == 10.0
+  logical, parameter :: test_parent9 = &
+    parent_const1%parent_field3 .eqv. .false.
+  ! Omitted components take their default initial values
+  type(child_type), parameter :: child_const2 = &
+    child_type( &
+      child_field3 = 14, parent_field3 = .true., &
+      parent_field1 = 15, child_field1 = 16.6)
+  logical, parameter :: test_child4 = child_const2%child_field1 == 16.6
+  logical, parameter :: test_child5 = child_const2%child_field2 .eqv. .false.
+  logical, parameter :: test_child6 = child_const2%child_field3 == 14
+  logical, parameter :: test_parent10 = child_const2%parent_field1 == 15
+  logical, parameter :: test_parent11 = child_const2%parent_field2 == 20.0
+  logical, parameter :: test_parent12 = child_const2%parent_field3 .eqv. .true.
+  ! Explicit parent_type value, keywords out of declared order
+  type(child_type), parameter :: child_const3 = &
+    child_type(parent_type( &
+      parent_field2 = 17.7, parent_field3 = .false., parent_field1 = 18), &
+        child_field2 = .false., child_field1 = 19.9, child_field3 = 21)
+  logical, parameter :: test_child7 = child_const3%parent_field1 == 18
+  logical, parameter :: test_child8 = child_const3%parent_field2 == 17.7
+  logical, parameter :: test_child9 = child_const3%parent_field3 .eqv. .false.
+  logical, parameter :: test_child10 = child_const3%child_field1 == 19.9
+  logical, parameter :: test_child11 = child_const3%child_field2 .eqv. .false.
+  logical, parameter :: test_child12 = child_const3%child_field3 == 21
+  ! Explicit parent_type value relying on parent and child defaults
+  type(child_type), parameter :: child_const4 = &
+    child_type(parent_type( &
+      parent_field3 = .true., parent_field1 = 22), &
+      child_field1 = 23.4, child_field3 = 24)
+  logical, parameter :: test_child13 = child_const4%parent_field1 == 22
+  logical, parameter :: test_child14 = child_const4%parent_field2 == 20.0
+  logical, parameter :: test_child15 = child_const4%parent_field3 .eqv. .true.
+  logical, parameter :: test_child16 = child_const4%child_field1 == 23.4
+  logical, parameter :: test_child17 = child_const4%child_field2 .eqv. .false.
+  logical, parameter :: test_child18 = child_const4%child_field3 == 24
+
+end module m3


        


More information about the flang-commits mailing list