[flang-commits] [flang] [flang] Fix mod file generation of derived type initializers... (PR #70511)
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Mon Oct 30 12:09:10 PDT 2023
https://github.com/klausler updated https://github.com/llvm/llvm-project/pull/70511
>From 9e26c4505d6f9a6244187c61b94541051dea8828 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Thu, 26 Oct 2023 15:26:46 -0700
Subject: [PATCH] [flang] Fix mod file generation of derived type
initializers...
... when the derived type used in the structure constructor(s) is
from another module and not use-associated into the current module.
This came up in a test with a derived type component default initializer
of "c_null_ptr", which is replaced with the expression
"__builtin_c_ptr(__address=0_8)"; the derived type name "__builtin_c_ptr"
is not available in the current scope, so the module file would fail
semantic analysis when USE'd.
The best solution that I found was to extend module file generation
to detect this case and handle it by inserting a private USE association
of the ultimate derived type symbol, renaming it to a compiler-created
name (module$name) in the case of a conflict.
To implement this transformation, it was necessary to fix the utility
evaluate::CollectSymbols() to include the derived type symbol from a
structure constructor or derived type constant. This involved extending
the expression traversal framework to visit the type symbol of a derived
type spec, and the derived type spec of a derived type constant.
Extending CollectSymbols() caused a lowering test to fail mysteriously,
so I tracked down the code in PFTBuilder that didn't expect to see
a DerivedTypeDetails symbol and made it skip such symbols when
analyzing the dependences of initialization expressions.
---
flang/include/flang/Evaluate/constant.h | 3 +-
flang/include/flang/Evaluate/expression.h | 3 +-
flang/include/flang/Evaluate/traverse.h | 21 ++--
flang/include/flang/Evaluate/type.h | 3 +-
flang/include/flang/Parser/message.h | 6 +
flang/lib/Evaluate/formatting.cpp | 15 ++-
flang/lib/Lower/PFTBuilder.cpp | 5 +-
flang/lib/Semantics/mod-file.cpp | 138 +++++++++++++++++-----
flang/lib/Semantics/mod-file.h | 2 +
flang/test/Semantics/modfile03.f90 | 4 +-
flang/test/Semantics/modfile59.f90 | 30 +++++
11 files changed, 180 insertions(+), 50 deletions(-)
create mode 100644 flang/test/Semantics/modfile59.f90
diff --git a/flang/include/flang/Evaluate/constant.h b/flang/include/flang/Evaluate/constant.h
index 8c841918bccbe89..ee83d9fc04f3b91 100644
--- a/flang/include/flang/Evaluate/constant.h
+++ b/flang/include/flang/Evaluate/constant.h
@@ -126,7 +126,8 @@ class ConstantBase : public ConstantBounds {
constexpr Result result() const { return result_; }
constexpr DynamicType GetType() const { return result_.GetType(); }
- llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
+ llvm::raw_ostream &AsFortran(llvm::raw_ostream &,
+ const parser::CharBlock *derivedTypeRename = nullptr) const;
protected:
std::vector<Element> Reshape(const ConstantSubscripts &) const;
diff --git a/flang/include/flang/Evaluate/expression.h b/flang/include/flang/Evaluate/expression.h
index 642ddf51168477d..64db0b88d03e587 100644
--- a/flang/include/flang/Evaluate/expression.h
+++ b/flang/include/flang/Evaluate/expression.h
@@ -735,7 +735,8 @@ class StructureConstructor {
StructureConstructor &Add(const semantics::Symbol &, Expr<SomeType> &&);
int Rank() const { return 0; }
DynamicType GetType() const;
- llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
+ llvm::raw_ostream &AsFortran(llvm::raw_ostream &,
+ const parser::CharBlock *derivedTypeRename = nullptr) const;
private:
std::optional<Expr<SomeType>> CreateParentComponent(const Symbol &) const;
diff --git a/flang/include/flang/Evaluate/traverse.h b/flang/include/flang/Evaluate/traverse.h
index 54cdb690ec0887e..8d75cc2df7247b4 100644
--- a/flang/include/flang/Evaluate/traverse.h
+++ b/flang/include/flang/Evaluate/traverse.h
@@ -100,16 +100,8 @@ template <typename Visitor, typename Result> class Traverse {
Result operator()(const NullPointer &) const { return visitor_.Default(); }
template <typename T> Result operator()(const Constant<T> &x) const {
if constexpr (T::category == TypeCategory::Derived) {
- std::optional<Result> result;
- for (const StructureConstructorValues &map : x.values()) {
- for (const auto &pair : map) {
- auto value{visitor_(pair.second.value())};
- result = result
- ? visitor_.Combine(std::move(*result), std::move(value))
- : std::move(value);
- }
- }
- return result ? *result : visitor_.Default();
+ return visitor_.Combine(
+ visitor_(x.result().derivedTypeSpec()), CombineContents(x.values()));
} else {
return visitor_.Default();
}
@@ -217,12 +209,19 @@ template <typename Visitor, typename Result> class Traverse {
const semantics::DerivedTypeSpec::ParameterMapType::value_type &x) const {
return visitor_(x.second);
}
+ Result operator()(
+ const semantics::DerivedTypeSpec::ParameterMapType &x) const {
+ return CombineContents(x);
+ }
Result operator()(const semantics::DerivedTypeSpec &x) const {
- return CombineContents(x.parameters());
+ return Combine(x.typeSymbol(), x.parameters());
}
Result operator()(const StructureConstructorValues::value_type &x) const {
return visitor_(x.second);
}
+ Result operator()(const StructureConstructorValues &x) const {
+ return CombineContents(x);
+ }
Result operator()(const StructureConstructor &x) const {
return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x));
}
diff --git a/flang/include/flang/Evaluate/type.h b/flang/include/flang/Evaluate/type.h
index ff784ef51f90244..73bff610ce80b06 100644
--- a/flang/include/flang/Evaluate/type.h
+++ b/flang/include/flang/Evaluate/type.h
@@ -268,7 +268,8 @@ const semantics::DerivedTypeSpec *GetDerivedTypeSpec(
const semantics::DerivedTypeSpec *GetParentTypeSpec(
const semantics::DerivedTypeSpec &);
-std::string DerivedTypeSpecAsFortran(const semantics::DerivedTypeSpec &);
+std::string DerivedTypeSpecAsFortran(const semantics::DerivedTypeSpec &,
+ const parser::CharBlock *derivedTypeRename = nullptr);
template <TypeCategory CATEGORY, int KIND = 0> struct TypeBase {
static constexpr TypeCategory category{CATEGORY};
diff --git a/flang/include/flang/Parser/message.h b/flang/include/flang/Parser/message.h
index 64b1298fb496974..668559aeec94780 100644
--- a/flang/include/flang/Parser/message.h
+++ b/flang/include/flang/Parser/message.h
@@ -121,6 +121,12 @@ class MessageFormattedText {
return *this;
}
std::string MoveString() { return std::move(string_); }
+ bool operator==(const MessageFormattedText &that) const {
+ return severity_ == that.severity_ && string_ == that.string_;
+ }
+ bool operator!=(const MessageFormattedText &that) const {
+ return !(*this == that);
+ }
private:
void Format(const MessageFixedText *, ...);
diff --git a/flang/lib/Evaluate/formatting.cpp b/flang/lib/Evaluate/formatting.cpp
index 5684c07657e61f1..5f822bbcbb04f49 100644
--- a/flang/lib/Evaluate/formatting.cpp
+++ b/flang/lib/Evaluate/formatting.cpp
@@ -53,7 +53,7 @@ static void ShapeAsFortran(llvm::raw_ostream &o,
template <typename RESULT, typename VALUE>
llvm::raw_ostream &ConstantBase<RESULT, VALUE>::AsFortran(
- llvm::raw_ostream &o) const {
+ llvm::raw_ostream &o, const parser::CharBlock *derivedTypeRename) const {
bool hasNonDefaultLowerBound{printLbounds && HasNonDefaultLowerBound()};
if (Rank() > 1 || hasNonDefaultLowerBound) {
o << "reshape(";
@@ -85,7 +85,8 @@ llvm::raw_ostream &ConstantBase<RESULT, VALUE>::AsFortran(
o << ".false." << '_' << Result::kind;
}
} else {
- StructureConstructor{result_.derivedTypeSpec(), value}.AsFortran(o);
+ StructureConstructor{result_.derivedTypeSpec(), value}.AsFortran(
+ o, derivedTypeRename);
}
}
if (Rank() > 0) {
@@ -503,8 +504,9 @@ llvm::raw_ostream &ExpressionBase<RESULT>::AsFortran(
return o;
}
-llvm::raw_ostream &StructureConstructor::AsFortran(llvm::raw_ostream &o) const {
- o << DerivedTypeSpecAsFortran(result_.derivedTypeSpec());
+llvm::raw_ostream &StructureConstructor::AsFortran(
+ llvm::raw_ostream &o, const parser::CharBlock *derivedTypeRename) const {
+ o << DerivedTypeSpecAsFortran(result_.derivedTypeSpec(), derivedTypeRename);
if (values_.empty()) {
o << '(';
} else {
@@ -566,10 +568,11 @@ std::string SomeDerived::AsFortran() const {
}
}
-std::string DerivedTypeSpecAsFortran(const semantics::DerivedTypeSpec &spec) {
+std::string DerivedTypeSpecAsFortran(const semantics::DerivedTypeSpec &spec,
+ const parser::CharBlock *derivedTypeRename) {
std::string buf;
llvm::raw_string_ostream ss{buf};
- ss << spec.name().ToString();
+ ss << (derivedTypeRename ? *derivedTypeRename : spec.name()).ToString();
char ch{'('};
for (const auto &[name, value] : spec.parameters()) {
ss << ch << name.ToString() << '=';
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index 0946a85dcaddd88..bc207ba4b9f1254 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -1513,7 +1513,7 @@ struct SymbolDependenceAnalysis {
// Derived type component symbols may be collected by "CollectSymbols"
// below when processing something like "real :: x(derived%component)". The
// symbol "component" has "ObjectEntityDetails", but it should not be
- // instantiated: it is is part of "derived" that should be the only one to
+ // instantiated: it is part of "derived" that should be the only one to
// be instantiated.
if (sym.owner().IsDerivedType())
return 0;
@@ -1569,7 +1569,8 @@ struct SymbolDependenceAnalysis {
// Handle any symbols in initialization expressions.
if (auto e = details->init())
for (const auto &s : evaluate::CollectSymbols(*e))
- depth = std::max(analyze(s) + 1, depth);
+ if (!s->has<semantics::DerivedTypeDetails>())
+ depth = std::max(analyze(s) + 1, depth);
}
adjustSize(depth + 1);
bool global = lower::symbolIsGlobal(sym);
diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index 8684eb1fbd3326d..70b6bbf8b557acb 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -44,10 +44,11 @@ struct ModHeader {
};
static std::optional<SourceName> GetSubmoduleParent(const parser::Program &);
-static void CollectSymbols(const Scope &, SymbolVector &, SymbolVector &);
+static void CollectSymbols(const Scope &, SymbolVector &, SymbolVector &,
+ std::map<const Symbol *, SourceName> &);
static void PutPassName(llvm::raw_ostream &, const std::optional<SourceName> &);
static void PutInit(llvm::raw_ostream &, const Symbol &, const MaybeExpr &,
- const parser::Expr *);
+ const parser::Expr *, const std::map<const Symbol *, SourceName> &);
static void PutInit(llvm::raw_ostream &, const MaybeIntExpr &);
static void PutBound(llvm::raw_ostream &, const Bound &);
static void PutShapeSpec(llvm::raw_ostream &, const ShapeSpec &);
@@ -173,11 +174,81 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) {
return all.str();
}
+// Gather all symbols referenced by the initializers of named constants and
+// derived type members; PrepareRenamings() picks those needing a new USE.
+static void HarvestInitializerSymbols(
+ SourceOrderedSymbolSet &set, const Scope &scope) {
+ for (const auto &[_, symbol] : scope) {
+ if (symbol->has<DerivedTypeDetails>()) {
+ if (symbol->scope()) {
+ HarvestInitializerSymbols(set, *symbol->scope());
+ }
+ } else if (IsNamedConstant(*symbol) || scope.IsDerivedType()) { // constant or type member
+ if (const auto *object{symbol->detailsIf<ObjectEntityDetails>()}) {
+ if (object->init()) {
+ for (SymbolRef ref : evaluate::CollectSymbols(*object->init())) {
+ set.emplace(*ref);
+ }
+ }
+ } else if (const auto *proc{symbol->detailsIf<ProcEntityDetails>()}) {
+ if (proc->init() && *proc->init()) {
+ set.emplace(**proc->init()); // the initial procedure target
+ }
+ }
+ }
+ }
+}
+// USE-associate (privately) initializer symbols not visible here; rename on clash.
+void ModFileWriter::PrepareRenamings(const Scope &scope) {
+ SourceOrderedSymbolSet symbolsInInits;
+ HarvestInitializerSymbols(symbolsInInits, scope);
+ for (SymbolRef s : symbolsInInits) {
+ const Scope *sMod{FindModuleContaining(s->owner())};
+ if (!sMod) { // not from a module: nothing to USE
+ continue;
+ }
+ SourceName rename{s->name()}; // original name unless it clashes
+ if (const Symbol * found{scope.FindSymbol(s->name())}) {
+ if (found == &*s) {
+ continue; // available in scope
+ }
+ if (const auto *generic{found->detailsIf<GenericDetails>()}) {
+ if (generic->derivedType() == &*s || generic->specific() == &*s) {
+ continue;
+ }
+ } else if (found->has<UseDetails>()) {
+ if (&found->GetUltimate() == &*s) {
+ continue; // already use-associated with same name
+ }
+ }
+ if (&s->owner() != &found->owner()) { // Symbol needs renaming
+ rename = scope.context().SaveTempName(
+ DEREF(sMod->symbol()).name().ToString() + "$" +
+ s->name().ToString()); // i.e. "module$name"
+ }
+ }
+ // Symbol is used in this scope but not visible under its name
+ if (sMod->parent().IsIntrinsicModules()) {
+ uses_ << "use,intrinsic::";
+ } else {
+ uses_ << "use ";
+ }
+ uses_ << DEREF(sMod->symbol()).name() << ",only:";
+ if (rename != s->name()) {
+ uses_ << rename << "=>";
+ }
+ uses_ << s->name() << '\n';
+ useExtraAttrs_ << "private::" << rename << '\n'; // don't re-export it
+ renamings_.emplace(&*s, rename); // consulted by PutInit()
+ }
+}
+
// Put out the visible symbols from scope.
void ModFileWriter::PutSymbols(const Scope &scope) {
SymbolVector sorted;
SymbolVector uses;
- CollectSymbols(scope, sorted, uses);
+ PrepareRenamings(scope);
+ CollectSymbols(scope, sorted, uses, renamings_);
std::string buf; // stuff after CONTAINS in derived type
llvm::raw_string_ostream typeBindings{buf};
for (const Symbol &symbol : sorted) {
@@ -404,7 +475,7 @@ void ModFileWriter::PutDECStructure(
}
decls_ << ref->name();
PutShape(decls_, object->shape(), '(', ')');
- PutInit(decls_, *ref, object->init(), nullptr);
+ PutInit(decls_, *ref, object->init(), nullptr, renamings_);
emittedDECFields_.insert(*ref);
} else if (any) {
break; // any later use of this structure will use RECORD/str/
@@ -628,31 +699,30 @@ static inline SourceName NameInModuleFile(const Symbol &symbol) {
// Collect the symbols of this scope sorted by their original order, not name.
// Generics and namelists are exceptions: they are sorted after other symbols.
-void CollectSymbols(
- const Scope &scope, SymbolVector &sorted, SymbolVector &uses) {
+void CollectSymbols(const Scope &scope, SymbolVector &sorted,
+ SymbolVector &uses, std::map<const Symbol *, SourceName> &renamings) {
SymbolVector namelist, generics;
- std::size_t commonSize{scope.commonBlocks().size()};
auto symbols{scope.GetSymbols()};
+ std::size_t commonSize{scope.commonBlocks().size()};
sorted.reserve(symbols.size() + commonSize);
for (SymbolRef symbol : symbols) {
- if (!symbol->test(Symbol::Flag::ParentComp)) {
- if (symbol->has<NamelistDetails>()) {
- namelist.push_back(symbol);
- } else if (const auto *generic{symbol->detailsIf<GenericDetails>()}) {
- if (generic->specific() &&
- &generic->specific()->owner() == &symbol->owner()) {
- sorted.push_back(*generic->specific());
- } else if (generic->derivedType() &&
- &generic->derivedType()->owner() == &symbol->owner()) {
- sorted.push_back(*generic->derivedType());
- }
- generics.push_back(symbol);
- } else {
- sorted.push_back(symbol);
- }
- if (const auto *details{symbol->detailsIf<GenericDetails>()}) {
- uses.insert(uses.end(), details->uses().begin(), details->uses().end());
+ if (symbol->test(Symbol::Flag::ParentComp)) {
+ } else if (symbol->has<NamelistDetails>()) {
+ namelist.push_back(symbol);
+ } else if (const auto *generic{symbol->detailsIf<GenericDetails>()}) {
+ if (generic->specific() &&
+ &generic->specific()->owner() == &symbol->owner()) {
+ sorted.push_back(*generic->specific());
+ } else if (generic->derivedType() &&
+ &generic->derivedType()->owner() == &symbol->owner()) {
+ sorted.push_back(*generic->derivedType());
}
+ generics.push_back(symbol);
+ } else {
+ sorted.push_back(symbol);
+ }
+ if (const auto *details{symbol->detailsIf<GenericDetails>()}) {
+ uses.insert(uses.end(), details->uses().begin(), details->uses().end());
}
}
// Sort most symbols by name: use of Symbol::ReplaceName ensures the source
@@ -734,7 +804,8 @@ void ModFileWriter::PutObjectEntity(
getSymbolAttrsToWrite(symbol));
PutShape(os, details.shape(), '(', ')');
PutShape(os, details.coshape(), '[', ']');
- PutInit(os, symbol, details.init(), details.unanalyzedPDTComponentInit());
+ PutInit(os, symbol, details.init(), details.unanalyzedPDTComponentInit(),
+ renamings_);
os << '\n';
if (auto tkr{GetIgnoreTKR(symbol)}; !tkr.empty()) {
os << "!dir$ ignore_tkr(";
@@ -828,12 +899,25 @@ void ModFileWriter::PutTypeParam(llvm::raw_ostream &os, const Symbol &symbol) {
}
void PutInit(llvm::raw_ostream &os, const Symbol &symbol, const MaybeExpr &init,
- const parser::Expr *unanalyzed) {
- if (symbol.attrs().test(Attr::PARAMETER) || symbol.owner().IsDerivedType()) {
+ const parser::Expr *unanalyzed,
+ const std::map<const Symbol *, SourceName> &renamings) {
+ if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) {
const char *assign{symbol.attrs().test(Attr::POINTER) ? "=>" : "="};
if (unanalyzed) {
parser::Unparse(os << assign, *unanalyzed);
} else if (init) {
+ if (const auto *dtConst{
+ evaluate::UnwrapExpr<evaluate::Constant<evaluate::SomeDerived>>(
+ *init)}) {
+ const Symbol &dtSym{dtConst->result().derivedTypeSpec().typeSymbol()};
+ if (auto iter{renamings.find(&dtSym)}; iter != renamings.end()) {
+ // Initializer is a constant whose derived type's name has
+ // been brought into scope from a module under a new name
+ // to avoid a conflict.
+ dtConst->AsFortran(os << assign, &iter->second);
+ return;
+ }
+ }
init->AsFortran(os << assign);
}
}
diff --git a/flang/lib/Semantics/mod-file.h b/flang/lib/Semantics/mod-file.h
index a2366a222d360ab..5be117153dd4d1d 100644
--- a/flang/lib/Semantics/mod-file.h
+++ b/flang/lib/Semantics/mod-file.h
@@ -52,11 +52,13 @@ class ModFileWriter {
llvm::raw_string_ostream decls_{declsBuf_};
llvm::raw_string_ostream contains_{containsBuf_};
bool isSubmodule_{false};
+ std::map<const Symbol *, SourceName> renamings_;
void WriteAll(const Scope &);
void WriteOne(const Scope &);
void Write(const Symbol &);
std::string GetAsString(const Symbol &);
+ void PrepareRenamings(const Scope &);
void PutSymbols(const Scope &);
// Returns true if a derived type with bindings and "contains" was emitted
bool PutComponents(const Symbol &);
diff --git a/flang/test/Semantics/modfile03.f90 b/flang/test/Semantics/modfile03.f90
index 8e2eb43e7e549fb..db0caeab973f853 100644
--- a/flang/test/Semantics/modfile03.f90
+++ b/flang/test/Semantics/modfile03.f90
@@ -135,8 +135,10 @@ module m6d
end
!Expect: m6d.mod
!module m6d
+! use m6a,only:t1
! use m6a,only:t2=>t1
-! type(t2),parameter::p=t2()
+! private::t1
+! type(t2),parameter::p=t1()
!end
module m6e
diff --git a/flang/test/Semantics/modfile59.f90 b/flang/test/Semantics/modfile59.f90
new file mode 100644
index 000000000000000..873451f72795cb7
--- /dev/null
+++ b/flang/test/Semantics/modfile59.f90
@@ -0,0 +1,30 @@
+! RUN: %python %S/test_modfile.py %s %flang_fc1
+! Test derived type renaming in initializers necessary to avoid
+! clashing with local names
+module m
+ use, intrinsic :: iso_c_binding, only: &
+ c_ptr, c_funptr, c_null_ptr, c_null_funptr
+ real, private :: __builtin_c_ptr, __builtin_c_funptr ! clash with builtin type names, forcing renames
+ type mydt
+ type(c_funptr) :: component = c_null_funptr
+ end type
+ type(c_ptr), parameter :: namedConst = c_null_ptr
+end
+
+!Expect: m.mod
+!module m
+!use,intrinsic::__fortran_builtins,only:__fortran_builtins$__builtin_c_ptr=>__builtin_c_ptr
+!use,intrinsic::__fortran_builtins,only:__fortran_builtins$__builtin_c_funptr=>__builtin_c_funptr
+!use,intrinsic::iso_c_binding,only:c_ptr
+!use,intrinsic::iso_c_binding,only:c_funptr
+!use,intrinsic::iso_c_binding,only:c_null_ptr
+!use,intrinsic::iso_c_binding,only:c_null_funptr
+!private::__fortran_builtins$__builtin_c_ptr
+!private::__fortran_builtins$__builtin_c_funptr
+!real(4),private::__builtin_c_ptr
+!real(4),private::__builtin_c_funptr
+!type::mydt
+!type(c_funptr)::component=__fortran_builtins$__builtin_c_funptr(__address=0_8)
+!end type
+!type(c_ptr),parameter::namedconst=__fortran_builtins$__builtin_c_ptr(__address=0_8)
+!end
More information about the flang-commits
mailing list