[flang-commits] [flang] [llvm] [flang] Improve runtime SAME_TYPE_AS() (PR #135670)
via flang-commits
flang-commits at lists.llvm.org
Mon Apr 14 13:26:59 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Peter Klausler (klausler)
<details>
<summary>Changes</summary>
The present implementation of the intrinsic function SAME_TYPE_AS() yields false-positive .TRUE. results for distinct derived types that happen to have the same name.
Replace it with an implementation that relies on the following property: two derived type information records describe the same type if and only if they are at the same address, or they are PDT instantiations of the same uninstantiated derived type. Also ensure that the derived type information includes references from instantiated PDTs to their original (uninstantiated) types. The derived type information format already supports these references, but they were not being set. This was perhaps because the faulty SAME_TYPE_AS implementation didn't need them, and nothing else does.
Fixes https://github.com/llvm/llvm-project/issues/135580.
---
Full diff: https://github.com/llvm/llvm-project/pull/135670.diff
4 Files Affected:
- (modified) flang-rt/include/flang-rt/runtime/type-info.h (+5-1)
- (modified) flang-rt/lib/runtime/derived-api.cpp (+14-36)
- (modified) flang/lib/Semantics/runtime-type-info.cpp (+14-9)
- (modified) flang/test/Semantics/typeinfo01.f90 (+2-1)
``````````diff
diff --git a/flang-rt/include/flang-rt/runtime/type-info.h b/flang-rt/include/flang-rt/runtime/type-info.h
index 9891fcecdcb25..5e79efde164f2 100644
--- a/flang-rt/include/flang-rt/runtime/type-info.h
+++ b/flang-rt/include/flang-rt/runtime/type-info.h
@@ -210,9 +210,13 @@ class DerivedType {
}
RT_API_ATTRS const Descriptor &name() const { return name_.descriptor(); }
RT_API_ATTRS std::uint64_t sizeInBytes() const { return sizeInBytes_; }
- RT_API_ATTRS const Descriptor &uninstatiated() const {
+ RT_API_ATTRS const Descriptor &uninstantiated() const {
return uninstantiated_.descriptor();
}
+ RT_API_ATTRS const DerivedType *uninstantiatedType() const {
+ return reinterpret_cast<const DerivedType *>(
+ uninstantiated().raw().base_addr);
+ }
RT_API_ATTRS const Descriptor &kindParameter() const {
return kindParameter_.descriptor();
}
diff --git a/flang-rt/lib/runtime/derived-api.cpp b/flang-rt/lib/runtime/derived-api.cpp
index 884fa8ee7d095..bb08e0397fe9c 100644
--- a/flang-rt/lib/runtime/derived-api.cpp
+++ b/flang-rt/lib/runtime/derived-api.cpp
@@ -83,27 +83,6 @@ bool RTDEF(ClassIs)(
return false;
}
-static RT_API_ATTRS bool CompareDerivedTypeNames(
- const Descriptor &a, const Descriptor &b) {
- if (a.raw().version == CFI_VERSION &&
- a.type() == TypeCode{TypeCategory::Character, 1} &&
- a.ElementBytes() > 0 && a.rank() == 0 && a.OffsetElement() != nullptr &&
- a.raw().version == CFI_VERSION &&
- b.type() == TypeCode{TypeCategory::Character, 1} &&
- b.ElementBytes() > 0 && b.rank() == 0 && b.OffsetElement() != nullptr &&
- a.ElementBytes() == b.ElementBytes() &&
- Fortran::runtime::memcmp(
- a.OffsetElement(), b.OffsetElement(), a.ElementBytes()) == 0) {
- return true;
- }
- return false;
-}
-
-inline RT_API_ATTRS bool CompareDerivedType(
- const typeInfo::DerivedType *a, const typeInfo::DerivedType *b) {
- return a == b || CompareDerivedTypeNames(a->name(), b->name());
-}
-
static RT_API_ATTRS const typeInfo::DerivedType *GetDerivedType(
const Descriptor &desc) {
if (const DescriptorAddendum * addendum{desc.Addendum()}) {
@@ -121,22 +100,21 @@ bool RTDEF(SameTypeAs)(const Descriptor &a, const Descriptor &b) {
(bType != CFI_type_struct && bType != CFI_type_other)) {
// If either type is intrinsic, they must match.
return aType == bType;
- } else {
- const typeInfo::DerivedType *derivedTypeA{GetDerivedType(a)};
- const typeInfo::DerivedType *derivedTypeB{GetDerivedType(b)};
- if (derivedTypeA == nullptr || derivedTypeB == nullptr) {
- // Unallocated/disassociated CLASS(*) never matches.
- return false;
- } else if (derivedTypeA == derivedTypeB) {
- // Exact match of derived type.
- return true;
- } else {
- // Otherwise compare with the name. Note 16.29 kind type parameters are
- // not considered in the test.
- return CompareDerivedTypeNames(
- derivedTypeA->name(), derivedTypeB->name());
+ } else if (const typeInfo::DerivedType * derivedTypeA{GetDerivedType(a)}) {
+ if (const typeInfo::DerivedType * derivedTypeB{GetDerivedType(b)}) {
+ if (derivedTypeA == derivedTypeB) {
+ return true;
+ } else if (const typeInfo::DerivedType *
+ uninstDerivedTypeA{derivedTypeA->uninstantiatedType()}) {
+ // There are KIND type parameters, are these the same type if those
+ // are ignored?
+ const typeInfo::DerivedType *uninstDerivedTypeB{
+ derivedTypeB->uninstantiatedType()};
+ return uninstDerivedTypeA == uninstDerivedTypeB;
+ }
}
}
+ return false;
}
bool RTDEF(ExtendsTypeOf)(const Descriptor &a, const Descriptor &mold) {
@@ -155,7 +133,7 @@ bool RTDEF(ExtendsTypeOf)(const Descriptor &a, const Descriptor &mold) {
// dynamic type of MOLD.
for (const typeInfo::DerivedType *derivedTypeA{GetDerivedType(a)};
derivedTypeA; derivedTypeA = derivedTypeA->GetParentType()) {
- if (CompareDerivedType(derivedTypeA, derivedTypeMold)) {
+ if (derivedTypeA == derivedTypeMold) {
return true;
}
}
diff --git a/flang/lib/Semantics/runtime-type-info.cpp b/flang/lib/Semantics/runtime-type-info.cpp
index 9f3eb5fbe11a1..98295f3705a71 100644
--- a/flang/lib/Semantics/runtime-type-info.cpp
+++ b/flang/lib/Semantics/runtime-type-info.cpp
@@ -58,7 +58,7 @@ class RuntimeTableBuilder {
void DescribeTypes(Scope &scope, bool inSchemata);
private:
- const Symbol *DescribeType(Scope &);
+ const Symbol *DescribeType(Scope &, bool wantUninstantiatedPDT);
const Symbol &GetSchemaSymbol(const char *) const;
const DeclTypeSpec &GetSchema(const char *) const;
SomeExpr GetEnumValue(const char *) const;
@@ -238,7 +238,7 @@ void RuntimeTableBuilder::DescribeTypes(Scope &scope, bool inSchemata) {
inSchemata |= ignoreScopes_.find(&scope) != ignoreScopes_.end();
if (scope.IsDerivedType()) {
if (!inSchemata) { // don't loop trying to describe a schema
- DescribeType(scope);
+ DescribeType(scope, /*wantUninstantiatedPDT=*/false);
}
} else {
scope.InstantiateDerivedTypes();
@@ -310,10 +310,10 @@ static SomeExpr StructureExpr(evaluate::StructureConstructor &&x) {
return SomeExpr{evaluate::Expr<evaluate::SomeDerived>{std::move(x)}};
}
-static int GetIntegerKind(const Symbol &symbol) {
+static int GetIntegerKind(const Symbol &symbol, bool canBeUninstantiated) {
auto dyType{evaluate::DynamicType::From(symbol)};
CHECK((dyType && dyType->category() == TypeCategory::Integer) ||
- symbol.owner().context().HasError(symbol));
+ symbol.owner().context().HasError(symbol) || canBeUninstantiated);
return dyType && dyType->category() == TypeCategory::Integer
? dyType->kind()
: symbol.owner().context().GetDefaultKind(TypeCategory::Integer);
@@ -395,7 +395,8 @@ static std::optional<std::string> GetSuffixIfTypeKindParameters(
return std::nullopt;
}
-const Symbol *RuntimeTableBuilder::DescribeType(Scope &dtScope) {
+const Symbol *RuntimeTableBuilder::DescribeType(
+ Scope &dtScope, bool wantUninstantiatedPDT) {
if (const Symbol * info{dtScope.runtimeDerivedTypeDescription()}) {
return info;
}
@@ -449,7 +450,7 @@ const Symbol *RuntimeTableBuilder::DescribeType(Scope &dtScope) {
GetSuffixIfTypeKindParameters(*derivedTypeSpec, parameters)}) {
distinctName += *suffix;
}
- } else if (isPDTDefinitionWithKindParameters) {
+ } else if (isPDTDefinitionWithKindParameters && !wantUninstantiatedPDT) {
return nullptr;
}
std::string dtDescName{(fir::kTypeDescriptorSeparator + distinctName).str()};
@@ -480,7 +481,8 @@ const Symbol *RuntimeTableBuilder::DescribeType(Scope &dtScope) {
}
if (const Symbol *
uninstDescObject{isPDTInstantiation
- ? DescribeType(DEREF(const_cast<Scope *>(dtSymbol->scope())))
+ ? DescribeType(DEREF(const_cast<Scope *>(dtSymbol->scope())),
+ /*wantUninstantiatedPDT=*/true)
: nullptr}) {
AddValue(dtValues, derivedTypeSchema_, "uninstantiated"s,
evaluate::AsGenericExpr(evaluate::Expr<evaluate::SomeDerived>{
@@ -516,7 +518,8 @@ const Symbol *RuntimeTableBuilder::DescribeType(Scope &dtScope) {
}
kinds.emplace_back(value);
} else { // LEN= parameter
- lenKinds.emplace_back(GetIntegerKind(*inst));
+ lenKinds.emplace_back(
+ GetIntegerKind(*inst, isPDTDefinitionWithKindParameters));
}
}
}
@@ -804,7 +807,9 @@ evaluate::StructureConstructor RuntimeTableBuilder::DescribeComponent(
const DerivedTypeSpec &spec{dyType.GetDerivedTypeSpec()};
Scope *derivedScope{const_cast<Scope *>(
spec.scope() ? spec.scope() : spec.typeSymbol().scope())};
- if (const Symbol * derivedDescription{DescribeType(DEREF(derivedScope))}) {
+ if (const Symbol *
+ derivedDescription{DescribeType(
+ DEREF(derivedScope), /*wantUninstantiatedPDT=*/false)}) {
AddValue(values, componentSchema_, "derived"s,
evaluate::AsGenericExpr(evaluate::Expr<evaluate::SomeDerived>{
evaluate::Designator<evaluate::SomeDerived>{
diff --git a/flang/test/Semantics/typeinfo01.f90 b/flang/test/Semantics/typeinfo01.f90
index 454d73098408d..c1427f28753cf 100644
--- a/flang/test/Semantics/typeinfo01.f90
+++ b/flang/test/Semantics/typeinfo01.f90
@@ -34,7 +34,8 @@ module m03
end type
type(kpdt(4)) :: x
!CHECK: .c.kpdt.4, SAVE, TARGET (CompilerCreated, ReadOnly): ObjectEntity type: TYPE(component) shape: 0_8:0_8 init:[component::component(name=.n.a,genre=1_1,category=2_1,kind=4_1,rank=0_1,offset=0_8,characterlen=value(genre=1_1,value=0_8),derived=NULL(),lenvalue=NULL(),bounds=NULL(),initialization=NULL())]
-!CHECK: .dt.kpdt.4, SAVE, TARGET (CompilerCreated, ReadOnly): ObjectEntity type: TYPE(derivedtype) init:derivedtype(binding=NULL(),name=.n.kpdt,sizeinbytes=4_8,uninstantiated=NULL(),kindparameter=.kp.kpdt.4,lenparameterkind=NULL(),component=.c.kpdt.4,procptr=NULL(),special=NULL(),specialbitset=0_4,hasparent=0_1,noinitializationneeded=1_1,nodestructionneeded=1_1,nofinalizationneeded=1_1)
+!CHECK: .dt.kpdt, SAVE, TARGET (CompilerCreated, ReadOnly): ObjectEntity type: TYPE(derivedtype) init:derivedtype(name=.n.kpdt,uninstantiated=NULL(),kindparameter=.kp.kpdt,lenparameterkind=NULL())
+!CHECK: .dt.kpdt.4, SAVE, TARGET (CompilerCreated, ReadOnly): ObjectEntity type: TYPE(derivedtype) init:derivedtype(binding=NULL(),name=.n.kpdt,sizeinbytes=4_8,uninstantiated=.dt.kpdt,kindparameter=.kp.kpdt.4,lenparameterkind=NULL(),component=.c.kpdt.4,procptr=NULL(),special=NULL(),specialbitset=0_4,hasparent=0_1,noinitializationneeded=1_1,nodestructionneeded=1_1,nofinalizationneeded=1_1)
!CHECK: .kp.kpdt.4, SAVE, TARGET (CompilerCreated, ReadOnly): ObjectEntity type: INTEGER(8) shape: 0_8:0_8 init:[INTEGER(8)::4_8]
end module
``````````
</details>
https://github.com/llvm/llvm-project/pull/135670
More information about the flang-commits mailing list