[flang-commits] [flang] [flang] Fix mod file generation of derived type initializers... (PR #70511)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Mon Oct 30 12:09:10 PDT 2023


https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/70511

>From 9e26c4505d6f9a6244187c61b94541051dea8828 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Thu, 26 Oct 2023 15:26:46 -0700
Subject: [PATCH] [flang] Fix mod file generation of derived type
 initializers...

... when the derived type used in the structure constructor(s) is
from another module and not use-associated into the current module.

This came up in a test with a derived type component default initializer
of "c_null_ptr", which is replaced with the expression
"__builtin_c_ptr(__address=0_8)"; the derived type name "__builtin_c_ptr"
is not available in the current scope, and the module file would fail
semantic analysis when USE'd.

The best solution that I found was to extend module file generation
to detect this case and handle it by inserting the right USE association
to the ultimate derived type symbol, possibly with renaming to a
compiler-created name in the case of a conflict.

To implement this transformation, it was necessary to fix the utility
evaluate::CollectSymbols() to include the derived type symbol from a
structure constructor.  This involved extending the expression traversal
framework to visit the derived type spec of a structure constructor.
Extending CollectSymbols() caused a lowering test to fail, because the
symbol dependence analysis in PFTBuilder did not expect to see a symbol
with DerivedTypeDetails among the symbols collected from an initializer;
that analysis now skips such symbols.
---
 flang/include/flang/Evaluate/constant.h   |   3 +-
 flang/include/flang/Evaluate/expression.h |   3 +-
 flang/include/flang/Evaluate/traverse.h   |  21 ++--
 flang/include/flang/Evaluate/type.h       |   3 +-
 flang/include/flang/Parser/message.h      |   6 +
 flang/lib/Evaluate/formatting.cpp         |  15 ++-
 flang/lib/Lower/PFTBuilder.cpp            |   5 +-
 flang/lib/Semantics/mod-file.cpp          | 138 +++++++++++++++++-----
 flang/lib/Semantics/mod-file.h            |   2 +
 flang/test/Semantics/modfile03.f90        |   4 +-
 flang/test/Semantics/modfile59.f90        |  30 +++++
 11 files changed, 180 insertions(+), 50 deletions(-)
 create mode 100644 flang/test/Semantics/modfile59.f90

diff --git a/flang/include/flang/Evaluate/constant.h b/flang/include/flang/Evaluate/constant.h
index 8c841918bccbe89..ee83d9fc04f3b91 100644
--- a/flang/include/flang/Evaluate/constant.h
+++ b/flang/include/flang/Evaluate/constant.h
@@ -126,7 +126,8 @@ class ConstantBase : public ConstantBounds {
   constexpr Result result() const { return result_; }
 
   constexpr DynamicType GetType() const { return result_.GetType(); }
-  llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
+  llvm::raw_ostream &AsFortran(llvm::raw_ostream &,
+      const parser::CharBlock *derivedTypeRename = nullptr) const;
 
 protected:
   std::vector<Element> Reshape(const ConstantSubscripts &) const;
diff --git a/flang/include/flang/Evaluate/expression.h b/flang/include/flang/Evaluate/expression.h
index 642ddf51168477d..64db0b88d03e587 100644
--- a/flang/include/flang/Evaluate/expression.h
+++ b/flang/include/flang/Evaluate/expression.h
@@ -735,7 +735,8 @@ class StructureConstructor {
   StructureConstructor &Add(const semantics::Symbol &, Expr<SomeType> &&);
   int Rank() const { return 0; }
   DynamicType GetType() const;
-  llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
+  llvm::raw_ostream &AsFortran(llvm::raw_ostream &,
+      const parser::CharBlock *derivedTypeRename = nullptr) const;
 
 private:
   std::optional<Expr<SomeType>> CreateParentComponent(const Symbol &) const;
diff --git a/flang/include/flang/Evaluate/traverse.h b/flang/include/flang/Evaluate/traverse.h
index 54cdb690ec0887e..8d75cc2df7247b4 100644
--- a/flang/include/flang/Evaluate/traverse.h
+++ b/flang/include/flang/Evaluate/traverse.h
@@ -100,16 +100,8 @@ template <typename Visitor, typename Result> class Traverse {
   Result operator()(const NullPointer &) const { return visitor_.Default(); }
   template <typename T> Result operator()(const Constant<T> &x) const {
     if constexpr (T::category == TypeCategory::Derived) {
-      std::optional<Result> result;
-      for (const StructureConstructorValues &map : x.values()) {
-        for (const auto &pair : map) {
-          auto value{visitor_(pair.second.value())};
-          result = result
-              ? visitor_.Combine(std::move(*result), std::move(value))
-              : std::move(value);
-        }
-      }
-      return result ? *result : visitor_.Default();
+      return visitor_.Combine(
+          visitor_(x.result().derivedTypeSpec()), CombineContents(x.values()));
     } else {
       return visitor_.Default();
     }
@@ -217,12 +209,19 @@ template <typename Visitor, typename Result> class Traverse {
       const semantics::DerivedTypeSpec::ParameterMapType::value_type &x) const {
     return visitor_(x.second);
   }
+  Result operator()(
+      const semantics::DerivedTypeSpec::ParameterMapType &x) const {
+    return CombineContents(x);
+  }
   Result operator()(const semantics::DerivedTypeSpec &x) const {
-    return CombineContents(x.parameters());
+    return Combine(x.typeSymbol(), x.parameters());
   }
   Result operator()(const StructureConstructorValues::value_type &x) const {
     return visitor_(x.second);
   }
+  Result operator()(const StructureConstructorValues &x) const {
+    return CombineContents(x);
+  }
   Result operator()(const StructureConstructor &x) const {
     return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x));
   }
diff --git a/flang/include/flang/Evaluate/type.h b/flang/include/flang/Evaluate/type.h
index ff784ef51f90244..73bff610ce80b06 100644
--- a/flang/include/flang/Evaluate/type.h
+++ b/flang/include/flang/Evaluate/type.h
@@ -268,7 +268,8 @@ const semantics::DerivedTypeSpec *GetDerivedTypeSpec(
 const semantics::DerivedTypeSpec *GetParentTypeSpec(
     const semantics::DerivedTypeSpec &);
 
-std::string DerivedTypeSpecAsFortran(const semantics::DerivedTypeSpec &);
+std::string DerivedTypeSpecAsFortran(const semantics::DerivedTypeSpec &,
+    const parser::CharBlock *derivedTypeRename = nullptr);
 
 template <TypeCategory CATEGORY, int KIND = 0> struct TypeBase {
   static constexpr TypeCategory category{CATEGORY};
diff --git a/flang/include/flang/Parser/message.h b/flang/include/flang/Parser/message.h
index 64b1298fb496974..668559aeec94780 100644
--- a/flang/include/flang/Parser/message.h
+++ b/flang/include/flang/Parser/message.h
@@ -121,6 +121,12 @@ class MessageFormattedText {
     return *this;
   }
   std::string MoveString() { return std::move(string_); }
+  bool operator==(const MessageFormattedText &that) const {
+    return severity_ == that.severity_ && string_ == that.string_;
+  }
+  bool operator!=(const MessageFormattedText &that) const {
+    return !(*this == that);
+  }
 
 private:
   void Format(const MessageFixedText *, ...);
diff --git a/flang/lib/Evaluate/formatting.cpp b/flang/lib/Evaluate/formatting.cpp
index 5684c07657e61f1..5f822bbcbb04f49 100644
--- a/flang/lib/Evaluate/formatting.cpp
+++ b/flang/lib/Evaluate/formatting.cpp
@@ -53,7 +53,7 @@ static void ShapeAsFortran(llvm::raw_ostream &o,
 
 template <typename RESULT, typename VALUE>
 llvm::raw_ostream &ConstantBase<RESULT, VALUE>::AsFortran(
-    llvm::raw_ostream &o) const {
+    llvm::raw_ostream &o, const parser::CharBlock *derivedTypeRename) const {
   bool hasNonDefaultLowerBound{printLbounds && HasNonDefaultLowerBound()};
   if (Rank() > 1 || hasNonDefaultLowerBound) {
     o << "reshape(";
@@ -85,7 +85,8 @@ llvm::raw_ostream &ConstantBase<RESULT, VALUE>::AsFortran(
         o << ".false." << '_' << Result::kind;
       }
     } else {
-      StructureConstructor{result_.derivedTypeSpec(), value}.AsFortran(o);
+      StructureConstructor{result_.derivedTypeSpec(), value}.AsFortran(
+          o, derivedTypeRename);
     }
   }
   if (Rank() > 0) {
@@ -503,8 +504,9 @@ llvm::raw_ostream &ExpressionBase<RESULT>::AsFortran(
   return o;
 }
 
-llvm::raw_ostream &StructureConstructor::AsFortran(llvm::raw_ostream &o) const {
-  o << DerivedTypeSpecAsFortran(result_.derivedTypeSpec());
+llvm::raw_ostream &StructureConstructor::AsFortran(
+    llvm::raw_ostream &o, const parser::CharBlock *derivedTypeRename) const {
+  o << DerivedTypeSpecAsFortran(result_.derivedTypeSpec(), derivedTypeRename);
   if (values_.empty()) {
     o << '(';
   } else {
@@ -566,10 +568,11 @@ std::string SomeDerived::AsFortran() const {
   }
 }
 
-std::string DerivedTypeSpecAsFortran(const semantics::DerivedTypeSpec &spec) {
+std::string DerivedTypeSpecAsFortran(const semantics::DerivedTypeSpec &spec,
+    const parser::CharBlock *derivedTypeRename) {
   std::string buf;
   llvm::raw_string_ostream ss{buf};
-  ss << spec.name().ToString();
+  ss << (derivedTypeRename ? *derivedTypeRename : spec.name()).ToString();
   char ch{'('};
   for (const auto &[name, value] : spec.parameters()) {
     ss << ch << name.ToString() << '=';
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index 0946a85dcaddd88..bc207ba4b9f1254 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -1513,7 +1513,7 @@ struct SymbolDependenceAnalysis {
     // Derived type component symbols may be collected by "CollectSymbols"
     // below when processing something like "real :: x(derived%component)". The
     // symbol "component" has "ObjectEntityDetails", but it should not be
-    // instantiated: it is is part of "derived" that should be the only one to
+    // instantiated: it is part of "derived" that should be the only one to
     // be instantiated.
     if (sym.owner().IsDerivedType())
       return 0;
@@ -1569,7 +1569,8 @@ struct SymbolDependenceAnalysis {
       // Handle any symbols in initialization expressions.
       if (auto e = details->init())
         for (const auto &s : evaluate::CollectSymbols(*e))
-          depth = std::max(analyze(s) + 1, depth);
+          if (!s->has<semantics::DerivedTypeDetails>())
+            depth = std::max(analyze(s) + 1, depth);
     }
     adjustSize(depth + 1);
     bool global = lower::symbolIsGlobal(sym);
diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index 8684eb1fbd3326d..70b6bbf8b557acb 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -44,10 +44,11 @@ struct ModHeader {
 };
 
 static std::optional<SourceName> GetSubmoduleParent(const parser::Program &);
-static void CollectSymbols(const Scope &, SymbolVector &, SymbolVector &);
+static void CollectSymbols(const Scope &, SymbolVector &, SymbolVector &,
+    std::map<const Symbol *, SourceName> &);
 static void PutPassName(llvm::raw_ostream &, const std::optional<SourceName> &);
 static void PutInit(llvm::raw_ostream &, const Symbol &, const MaybeExpr &,
-    const parser::Expr *);
+    const parser::Expr *, const std::map<const Symbol *, SourceName> &);
 static void PutInit(llvm::raw_ostream &, const MaybeIntExpr &);
 static void PutBound(llvm::raw_ostream &, const Bound &);
 static void PutShapeSpec(llvm::raw_ostream &, const ShapeSpec &);
@@ -173,11 +174,81 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) {
   return all.str();
 }
 
+// Collect symbols from initializations that are being referenced directly
+// from other modules; they may require new USE associations.
+static void HarvestInitializerSymbols(
+    SourceOrderedSymbolSet &set, const Scope &scope) {
+  for (const auto &[_, symbol] : scope) {
+    if (symbol->has<DerivedTypeDetails>()) {
+      if (symbol->scope()) {
+        HarvestInitializerSymbols(set, *symbol->scope());
+      }
+    } else if (IsNamedConstant(*symbol) || scope.IsDerivedType()) {
+      if (const auto *object{symbol->detailsIf<ObjectEntityDetails>()}) {
+        if (object->init()) {
+          for (SymbolRef ref : evaluate::CollectSymbols(*object->init())) {
+            set.emplace(*ref);
+          }
+        }
+      } else if (const auto *proc{symbol->detailsIf<ProcEntityDetails>()}) {
+        if (proc->init() && *proc->init()) {
+          set.emplace(**proc->init());
+        }
+      }
+    }
+  }
+}
+
+void ModFileWriter::PrepareRenamings(const Scope &scope) {
+  SourceOrderedSymbolSet symbolsInInits;
+  HarvestInitializerSymbols(symbolsInInits, scope);
+  for (SymbolRef s : symbolsInInits) {
+    const Scope *sMod{FindModuleContaining(s->owner())};
+    if (!sMod) {
+      continue;
+    }
+    SourceName rename{s->name()};
+    if (const Symbol * found{scope.FindSymbol(s->name())}) {
+      if (found == &*s) {
+        continue; // available in scope
+      }
+      if (const auto *generic{found->detailsIf<GenericDetails>()}) {
+        if (generic->derivedType() == &*s || generic->specific() == &*s) {
+          continue;
+        }
+      } else if (found->has<UseDetails>()) {
+        if (&found->GetUltimate() == &*s) {
+          continue; // already use-associated with same name
+        }
+      }
+      if (&s->owner() != &found->owner()) { // Symbol needs renaming
+        rename = scope.context().SaveTempName(
+            DEREF(sMod->symbol()).name().ToString() + "$" +
+            s->name().ToString());
+      }
+    }
+    // Symbol is used in this scope but not visible under its name
+    if (sMod->parent().IsIntrinsicModules()) {
+      uses_ << "use,intrinsic::";
+    } else {
+      uses_ << "use ";
+    }
+    uses_ << DEREF(sMod->symbol()).name() << ",only:";
+    if (rename != s->name()) {
+      uses_ << rename << "=>";
+    }
+    uses_ << s->name() << '\n';
+    useExtraAttrs_ << "private::" << rename << '\n';
+    renamings_.emplace(&*s, rename);
+  }
+}
+
 // Put out the visible symbols from scope.
 void ModFileWriter::PutSymbols(const Scope &scope) {
   SymbolVector sorted;
   SymbolVector uses;
-  CollectSymbols(scope, sorted, uses);
+  PrepareRenamings(scope);
+  CollectSymbols(scope, sorted, uses, renamings_);
   std::string buf; // stuff after CONTAINS in derived type
   llvm::raw_string_ostream typeBindings{buf};
   for (const Symbol &symbol : sorted) {
@@ -404,7 +475,7 @@ void ModFileWriter::PutDECStructure(
         }
         decls_ << ref->name();
         PutShape(decls_, object->shape(), '(', ')');
-        PutInit(decls_, *ref, object->init(), nullptr);
+        PutInit(decls_, *ref, object->init(), nullptr, renamings_);
         emittedDECFields_.insert(*ref);
       } else if (any) {
         break; // any later use of this structure will use RECORD/str/
@@ -628,31 +699,30 @@ static inline SourceName NameInModuleFile(const Symbol &symbol) {
 
 // Collect the symbols of this scope sorted by their original order, not name.
 // Generics and namelists are exceptions: they are sorted after other symbols.
-void CollectSymbols(
-    const Scope &scope, SymbolVector &sorted, SymbolVector &uses) {
+void CollectSymbols(const Scope &scope, SymbolVector &sorted,
+    SymbolVector &uses, std::map<const Symbol *, SourceName> &renamings) {
   SymbolVector namelist, generics;
-  std::size_t commonSize{scope.commonBlocks().size()};
   auto symbols{scope.GetSymbols()};
+  std::size_t commonSize{scope.commonBlocks().size()};
   sorted.reserve(symbols.size() + commonSize);
   for (SymbolRef symbol : symbols) {
-    if (!symbol->test(Symbol::Flag::ParentComp)) {
-      if (symbol->has<NamelistDetails>()) {
-        namelist.push_back(symbol);
-      } else if (const auto *generic{symbol->detailsIf<GenericDetails>()}) {
-        if (generic->specific() &&
-            &generic->specific()->owner() == &symbol->owner()) {
-          sorted.push_back(*generic->specific());
-        } else if (generic->derivedType() &&
-            &generic->derivedType()->owner() == &symbol->owner()) {
-          sorted.push_back(*generic->derivedType());
-        }
-        generics.push_back(symbol);
-      } else {
-        sorted.push_back(symbol);
-      }
-      if (const auto *details{symbol->detailsIf<GenericDetails>()}) {
-        uses.insert(uses.end(), details->uses().begin(), details->uses().end());
+    if (symbol->test(Symbol::Flag::ParentComp)) {
+    } else if (symbol->has<NamelistDetails>()) {
+      namelist.push_back(symbol);
+    } else if (const auto *generic{symbol->detailsIf<GenericDetails>()}) {
+      if (generic->specific() &&
+          &generic->specific()->owner() == &symbol->owner()) {
+        sorted.push_back(*generic->specific());
+      } else if (generic->derivedType() &&
+          &generic->derivedType()->owner() == &symbol->owner()) {
+        sorted.push_back(*generic->derivedType());
       }
+      generics.push_back(symbol);
+    } else {
+      sorted.push_back(symbol);
+    }
+    if (const auto *details{symbol->detailsIf<GenericDetails>()}) {
+      uses.insert(uses.end(), details->uses().begin(), details->uses().end());
     }
   }
   // Sort most symbols by name: use of Symbol::ReplaceName ensures the source
@@ -734,7 +804,8 @@ void ModFileWriter::PutObjectEntity(
       getSymbolAttrsToWrite(symbol));
   PutShape(os, details.shape(), '(', ')');
   PutShape(os, details.coshape(), '[', ']');
-  PutInit(os, symbol, details.init(), details.unanalyzedPDTComponentInit());
+  PutInit(os, symbol, details.init(), details.unanalyzedPDTComponentInit(),
+      renamings_);
   os << '\n';
   if (auto tkr{GetIgnoreTKR(symbol)}; !tkr.empty()) {
     os << "!dir$ ignore_tkr(";
@@ -828,12 +899,25 @@ void ModFileWriter::PutTypeParam(llvm::raw_ostream &os, const Symbol &symbol) {
 }
 
 void PutInit(llvm::raw_ostream &os, const Symbol &symbol, const MaybeExpr &init,
-    const parser::Expr *unanalyzed) {
-  if (symbol.attrs().test(Attr::PARAMETER) || symbol.owner().IsDerivedType()) {
+    const parser::Expr *unanalyzed,
+    const std::map<const Symbol *, SourceName> &renamings) {
+  if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) {
     const char *assign{symbol.attrs().test(Attr::POINTER) ? "=>" : "="};
     if (unanalyzed) {
       parser::Unparse(os << assign, *unanalyzed);
     } else if (init) {
+      if (const auto *dtConst{
+              evaluate::UnwrapExpr<evaluate::Constant<evaluate::SomeDerived>>(
+                  *init)}) {
+        const Symbol &dtSym{dtConst->result().derivedTypeSpec().typeSymbol()};
+        if (auto iter{renamings.find(&dtSym)}; iter != renamings.end()) {
+          // Initializer is a constant whose derived type's name has
+          // been brought into scope from a module under a new name
+          // to avoid a conflict.
+          dtConst->AsFortran(os << assign, &iter->second);
+          return;
+        }
+      }
       init->AsFortran(os << assign);
     }
   }
diff --git a/flang/lib/Semantics/mod-file.h b/flang/lib/Semantics/mod-file.h
index a2366a222d360ab..5be117153dd4d1d 100644
--- a/flang/lib/Semantics/mod-file.h
+++ b/flang/lib/Semantics/mod-file.h
@@ -52,11 +52,13 @@ class ModFileWriter {
   llvm::raw_string_ostream decls_{declsBuf_};
   llvm::raw_string_ostream contains_{containsBuf_};
   bool isSubmodule_{false};
+  std::map<const Symbol *, SourceName> renamings_;
 
   void WriteAll(const Scope &);
   void WriteOne(const Scope &);
   void Write(const Symbol &);
   std::string GetAsString(const Symbol &);
+  void PrepareRenamings(const Scope &);
   void PutSymbols(const Scope &);
   // Returns true if a derived type with bindings and "contains" was emitted
   bool PutComponents(const Symbol &);
diff --git a/flang/test/Semantics/modfile03.f90 b/flang/test/Semantics/modfile03.f90
index 8e2eb43e7e549fb..db0caeab973f853 100644
--- a/flang/test/Semantics/modfile03.f90
+++ b/flang/test/Semantics/modfile03.f90
@@ -135,8 +135,10 @@ module m6d
 end
 !Expect: m6d.mod
 !module m6d
+! use m6a,only:t1
 ! use m6a,only:t2=>t1
-! type(t2),parameter::p=t2()
+! private::t1
+! type(t2),parameter::p=t1()
 !end
 
 module m6e
diff --git a/flang/test/Semantics/modfile59.f90 b/flang/test/Semantics/modfile59.f90
new file mode 100644
index 000000000000000..873451f72795cb7
--- /dev/null
+++ b/flang/test/Semantics/modfile59.f90
@@ -0,0 +1,30 @@
+! RUN: %python %S/test_modfile.py %s %flang_fc1
+! Test derived type renaming in initializers necessary to avoid
+! clashing with local names
+module m
+  use, intrinsic :: iso_c_binding, only: &
+    c_ptr, c_funptr, c_null_ptr, c_null_funptr
+  real, private :: __builtin_c_ptr, __builtin_c_funptr
+  type mydt
+    type(c_funptr) :: component = c_null_funptr
+  end type
+  type(c_ptr), parameter :: namedConst = c_null_ptr
+end
+
+!Expect: m.mod
+!module m
+!use,intrinsic::__fortran_builtins,only:__fortran_builtins$__builtin_c_ptr=>__builtin_c_ptr
+!use,intrinsic::__fortran_builtins,only:__fortran_builtins$__builtin_c_funptr=>__builtin_c_funptr
+!use,intrinsic::iso_c_binding,only:c_ptr
+!use,intrinsic::iso_c_binding,only:c_funptr
+!use,intrinsic::iso_c_binding,only:c_null_ptr
+!use,intrinsic::iso_c_binding,only:c_null_funptr
+!private::__fortran_builtins$__builtin_c_ptr
+!private::__fortran_builtins$__builtin_c_funptr
+!real(4),private::__builtin_c_ptr
+!real(4),private::__builtin_c_funptr
+!type::mydt
+!type(c_funptr)::component=__fortran_builtins$__builtin_c_funptr(__address=0_8)
+!end type
+!type(c_ptr),parameter::namedconst=__fortran_builtins$__builtin_c_ptr(__address=0_8)
+!end



More information about the flang-commits mailing list