[flang-commits] [flang] [flang] Enumeration Type: (PR 2/5) Name Resolution + Expression + Relational + SELECT CASE (PR #193028)
via flang-commits
flang-commits at lists.llvm.org
Wed Jun 10 08:25:13 PDT 2026
https://github.com/kwyatt-ext updated https://github.com/llvm/llvm-project/pull/193028
>From 71c3222d0c1ccb893aabfd7cbf74c813cd452f41 Mon Sep 17 00:00:00 2001
From: Kevin Wyatt <kwyatt at hpe.com>
Date: Thu, 16 Apr 2026 12:49:27 -0500
Subject: [PATCH 1/2] Enumeration Type Sem-2: Name Resolution + Expression +
Relational + SELECT CASE (PRs 3-5)
Adds name resolution for ENUMERATION TYPE (the 'not yet implemented'
stub itself is removed in PATCH 2/2), expression analysis for enumeration
constructors using a hidden __ordinal component, relational operator
support, and SELECT CASE support.
Includes PR9 bug fixes: scope()->GetScope() in expression.cpp, and
Relate() INT() wrapping for non-constant enum comparisons in
evaluate/tools.cpp.
Files from original PRs 3-5 plus targeted fixes from PR 9.
---
flang/include/flang/Semantics/expression.h | 3 +
flang/lib/Evaluate/formatting.cpp | 8 ++
flang/lib/Evaluate/tools.cpp | 73 ++++++++++++
flang/lib/Semantics/check-case.cpp | 115 +++++++++++++++++-
flang/lib/Semantics/check-declarations.cpp | 4 +
flang/lib/Semantics/expression.cpp | 78 +++++++++++++
flang/lib/Semantics/resolve-labels.cpp | 8 ++
flang/lib/Semantics/resolve-names.cpp | 128 ++++++++++++++++++++-
flang/lib/Semantics/rewrite-parse-tree.cpp | 1 +
flang/lib/Semantics/tools.cpp | 13 ++-
flang/test/Semantics/case01.f90 | 4 +-
11 files changed, 427 insertions(+), 8 deletions(-)
diff --git a/flang/include/flang/Semantics/expression.h b/flang/include/flang/Semantics/expression.h
index f93b9a892715a..404499ada1b14 100644
--- a/flang/include/flang/Semantics/expression.h
+++ b/flang/include/flang/Semantics/expression.h
@@ -414,6 +414,9 @@ class ExpressionAnalyzer {
};
MaybeExpr CheckStructureConstructor(parser::CharBlock typeName,
const semantics::DerivedTypeSpec &, std::list<ComponentSpec> &&);
+ MaybeExpr AnalyzeEnumerationConstructor(parser::CharBlock typeName,
+ const semantics::DerivedTypeSpec &,
+ const std::list<parser::ComponentSpec> &);
MaybeExpr IterativelyAnalyzeSubexpressions(const parser::Expr &);
diff --git a/flang/lib/Evaluate/formatting.cpp b/flang/lib/Evaluate/formatting.cpp
index 3604484254196..10c834872b863 100644
--- a/flang/lib/Evaluate/formatting.cpp
+++ b/flang/lib/Evaluate/formatting.cpp
@@ -664,6 +664,14 @@ static std::string DerivedTypeSpecAsFortran(
llvm::raw_ostream &StructureConstructor::AsFortran(llvm::raw_ostream &o) const {
o << DerivedTypeSpecAsFortran(result_.derivedTypeSpec());
+ if (result_.derivedTypeSpec().IsEnumerationType()) {
+ // Print as enum_name(ordinal) without exposing the hidden __ordinal keyword
+ o << '(';
+ if (!values_.empty()) {
+ values_.begin()->second.value().AsFortran(o);
+ }
+ return o << ')';
+ }
if (values_.empty()) {
o << '(';
} else {
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 82dcd1e795f49..9f6cf6f97501a 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -755,6 +755,79 @@ std::optional<Expr<LogicalResult>> Relate(parser::ContextualMessages &messages,
},
std::move(cx.u), std::move(cy.u));
},
+ [&](Expr<SomeDerived> &&dx,
+ Expr<SomeDerived> &&dy) -> std::optional<Expr<LogicalResult>> {
+ // Enumeration type comparison: extract __ordinal and delegate
+ // to integer comparison
+ auto xType{dx.GetType()};
+ auto yType{dy.GetType()};
+ if (xType && yType) {
+ const auto *xDerived{GetDerivedTypeSpec(*xType)};
+ const auto *yDerived{GetDerivedTypeSpec(*yType)};
+ if (xDerived && yDerived && xDerived->IsEnumerationType() &&
+ yDerived->IsEnumerationType() &&
+ &xDerived->typeSymbol() == &yDerived->typeSymbol()) {
+ if (const auto *scope{xDerived->GetScope()}) {
+ auto ordIter{
+ scope->find(semantics::SourceName{"__ordinal", 9})};
+ if (ordIter != scope->end()) {
+ const semantics::Symbol &ordSym{*ordIter->second};
+ // Try to extract from Constant<SomeDerived>
+ auto extractOrdinal = [&](Expr<SomeDerived> &expr)
+ -> std::optional<Expr<SomeType>> {
+ if (auto *constant{
+ UnwrapConstantValue<SomeDerived>(expr)}) {
+ if (auto sc{constant->GetScalarValue()}) {
+ return sc->Find(ordSym);
+ }
+ } else if (auto *sc{
+ UnwrapExpr<StructureConstructor>(expr)}) {
+ return sc->Find(ordSym);
+ }
+ return std::nullopt;
+ };
+ auto xOrd{extractOrdinal(dx)};
+ auto yOrd{extractOrdinal(dy)};
+ if (xOrd && yOrd) {
+ return Relate(
+ messages, opr, std::move(*xOrd), std::move(*yOrd));
+ }
+ // Non-constant operands: wrap in INT() to convert to
+ // integer comparison. Build FunctionRef<Int4> for each
+ // operand representing INT(enumExpr).
+ auto makeIntCall =
+ [&](Expr<SomeDerived> &&operand) -> Expr<SomeType> {
+ using IntType = Type<TypeCategory::Integer, 4>;
+ DynamicType enumType{*xDerived};
+ DynamicType intResultType{TypeCategory::Integer, 4};
+ characteristics::DummyDataObject ddo{
+ characteristics::TypeAndShape{enumType}};
+ ddo.intent = common::Intent::In;
+ characteristics::Procedure::Attrs attrs;
+ attrs.set(characteristics::Procedure::Attr::Pure);
+ attrs.set(characteristics::Procedure::Attr::Elemental);
+ characteristics::DummyArguments dummies;
+ dummies.emplace_back("a"s, std::move(ddo));
+ SpecificIntrinsic intSpec{"int"s,
+ characteristics::Procedure{
+ characteristics::FunctionResult{intResultType},
+ std::move(dummies), attrs}};
+ ActualArguments intArgs;
+ intArgs.emplace_back(AsGenericExpr(std::move(operand)));
+ return AsGenericExpr(
+ Expr<SomeInteger>(Expr<IntType>(FunctionRef<IntType>{
+ ProcedureDesignator{std::move(intSpec)},
+ std::move(intArgs)})));
+ };
+ return Relate(messages, opr, makeIntCall(std::move(dx)),
+ makeIntCall(std::move(dy)));
+ }
+ }
+ }
+ }
+ DIE("invalid types for relational operator");
+ return std::optional<Expr<LogicalResult>>{};
+ },
// Default case
[&](auto &&, auto &&) {
DIE("invalid types for relational operator");
diff --git a/flang/lib/Semantics/check-case.cpp b/flang/lib/Semantics/check-case.cpp
index 9004d8b3a28f9..fa0c08c317ba1 100644
--- a/flang/lib/Semantics/check-case.cpp
+++ b/flang/lib/Semantics/check-case.cpp
@@ -11,6 +11,7 @@
#include "flang/Common/reference.h"
#include "flang/Common/template.h"
#include "flang/Evaluate/fold.h"
+#include "flang/Evaluate/tools.h"
#include "flang/Evaluate/type.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/semantics.h"
@@ -236,6 +237,103 @@ template <TypeCategory CAT> struct TypeVisitor {
const std::list<parser::CaseConstruct::Case> &caseList;
};
+// Check one CASE value of an enumeration-typed SELECT CASE and, on success,
+// replace its typed expression with the value's __ordinal integer.
+// Returns false (after emitting a diagnostic where applicable) when the
+// value has no analyzed expression, has a type other than the selector's
+// enumeration type, or does not fold to a constant scalar.
+static bool ConvertEnumCaseValue(SemanticsContext &context,
+    const parser::CaseValue &caseValue,
+    const semantics::DerivedTypeSpec &enumType,
+    const semantics::Symbol &ordSym) {
+  const auto &parsedExpr{parser::UnwrapRef<parser::Expr>(caseValue)};
+  auto *typed{parsedExpr.typedExpr.get()};
+  if (!typed || !typed->v) {
+    return false; // no analyzed expression to convert
+  }
+  auto type{typed->v->GetType()};
+  // The value must have exactly the selector's enumeration type.
+  const semantics::DerivedTypeSpec *valueDerived{
+      type && type->category() == TypeCategory::Derived
+          ? evaluate::GetDerivedTypeSpec(*type)
+          : nullptr};
+  bool sameEnumeration{valueDerived && valueDerived->IsEnumerationType() &&
+      &valueDerived->typeSymbol() == &enumType.typeSymbol()};
+  if (!sameEnumeration) {
+    context.Say(parsedExpr.source,
+        "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
+        type ? type->AsFortran() : "typeless"s, enumType.AsFortran());
+    return false;
+  }
+  // Fold the value (folding messages are discarded) and extract the
+  // ordinal from the resulting scalar constant, if any.
+  parser::Messages discarded;
+  parser::ContextualMessages foldingMessages{parsedExpr.source, &discarded};
+  evaluate::FoldingContext foldingContext{
+      context.foldingContext(), foldingMessages};
+  SomeExpr folded{evaluate::Fold(foldingContext, SomeExpr{*typed->v})};
+  std::optional<SomeExpr> ordinal;
+  if (auto constant{
+          evaluate::GetScalarConstantValue<evaluate::SomeDerived>(folded)}) {
+    ordinal = constant->Find(ordSym);
+  }
+  if (!ordinal) {
+    context.Say(parsedExpr.source,
+        "CASE value (%s) must be a constant scalar"_err_en_US,
+        typed->v->AsFortran());
+    return false;
+  }
+  typed->v = std::move(*ordinal);
+  return true;
+}
+
+// Check every CASE value of an enumeration-typed SELECT CASE against the
+// selector's type and rewrite each one to its __ordinal integer.  Every
+// value is visited even after a failure, so that each bad value gets its
+// own diagnostic.  Returns true only if every value was converted; returns
+// false without diagnostics if the type has no scope or no __ordinal.
+static bool ConvertEnumCaseValues(SemanticsContext &context,
+    const std::list<parser::CaseConstruct::Case> &cases,
+    const semantics::DerivedTypeSpec &enumType) {
+  const auto *scope{enumType.GetScope()};
+  if (!scope) {
+    return false;
+  }
+  auto ordIter{scope->find(semantics::SourceName{"__ordinal", 9})};
+  if (ordIter == scope->end()) {
+    return false;
+  }
+  const semantics::Symbol &ordSym{*ordIter->second};
+  bool allConverted{true};
+  auto convert{[&](const parser::CaseValue &value) {
+    // Call the converter first so that it is never short-circuited.
+    allConverted = ConvertEnumCaseValue(context, value, enumType, ordSym) &&
+        allConverted;
+  }};
+  for (const auto &caseItem : cases) {
+    const auto &caseStmt{
+        std::get<parser::Statement<parser::CaseStmt>>(caseItem.t)};
+    const auto &caseSelector{
+        std::get<parser::CaseSelector>(caseStmt.statement.t)};
+    const auto *ranges{
+        std::get_if<std::list<parser::CaseValueRange>>(&caseSelector.u)};
+    if (!ranges) {
+      continue; // CASE DEFAULT has no values to convert
+    }
+    for (const auto &range : *ranges) {
+      if (const auto *single{std::get_if<parser::CaseValue>(&range.u)}) {
+        convert(*single);
+      } else if (const auto *bounds{
+                     std::get_if<parser::CaseValueRange::Range>(&range.u)}) {
+        const auto &[lower, upper]{bounds->t};
+        if (lower) {
+          convert(*lower);
+        }
+        if (upper) {
+          convert(*upper);
+        }
+      }
+    }
+  }
+  return allConverted;
+}
+
void CaseChecker::Enter(const parser::CaseConstruct &construct) {
const auto &selectCaseStmt{
std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)};
@@ -266,13 +364,26 @@ void CaseChecker::Enter(const parser::CaseConstruct &construct) {
common::SearchTypes(
TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList});
return;
+ case TypeCategory::Derived:
+ if (const auto *derived{evaluate::GetDerivedTypeSpec(*exprType)}) {
+ if (derived->IsEnumerationType()) {
+ if (ConvertEnumCaseValues(context_, caseList, *derived)) {
+ evaluate::DynamicType intType{TypeCategory::Integer, 4};
+ CaseValues<evaluate::Type<TypeCategory::Integer, 4>>{
+ context_, intType}
+ .Check(caseList);
+ }
+ return;
+ }
+ }
+ break;
default:
break;
}
}
context_.Say(selectExpr.source,
context_.IsEnabled(common::LanguageFeature::Unsigned)
- ? "SELECT CASE expression must be integer, unsigned, logical, or character"_err_en_US
- : "SELECT CASE expression must be integer, logical, or character"_err_en_US);
+ ? "SELECT CASE expression must be integer, unsigned, logical, character, or enumeration type"_err_en_US
+ : "SELECT CASE expression must be integer, logical, character, or enumeration type"_err_en_US);
}
} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index 6a2fd40aeec79..a11b6bc4d691c 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -1841,6 +1841,10 @@ void CheckHelper::CheckExternal(const Symbol &symbol) {
void CheckHelper::CheckDerivedType(
const Symbol &derivedType, const DerivedTypeDetails &details) {
+ if (details.isEnumerationType()) {
+ // Enumeration types have no user-declared components, type parameters,
+ // or bindings to check; their only component is the compiler-created
+ // __ordinal that name resolution adds.
+ return;
+ }
if (details.isForwardReferenced() && !context_.HasError(derivedType)) {
messages_.Say("The derived type '%s' has not been defined"_err_en_US,
derivedType.name());
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 50869a3c870ef..26bd9fa3b382e 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -1482,6 +1482,12 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::StructureComponent &sc) {
const auto &name{sc.Component().source};
if (auto *dtExpr{UnwrapExpr<Expr<SomeDerived>>(*base)}) {
const auto *dtSpec{GetDerivedTypeSpec(dtExpr->GetType())};
+ if (dtSpec && dtSpec->IsEnumerationType()) {
+ Say(name,
+ "Component reference is not allowed for enumeration type '%s'"_err_en_US,
+ dtSpec->typeSymbol().name());
+ return std::nullopt;
+ }
if (isTypeParamInquiry) {
if (auto *designator{UnwrapExpr<Designator<SomeDerived>>(*dtExpr)}) {
if (std::optional<DynamicType> dyType{DynamicType::From(*sym)}) {
@@ -2462,6 +2468,65 @@ MaybeExpr ExpressionAnalyzer::CheckStructureConstructor(
return AsMaybeExpr(Expr<SomeDerived>{std::move(result)});
}
+// F2023 R771: enumeration-constructor is
+//   enumeration-type-spec ( scalar-int-expr )
+// The scalar-int-expr shall have a value that is positive and less than or
+// equal to the number of enumerators in the enumeration type.
+// On success, the result is an Expr<SomeDerived> that holds a structure
+// constructor. Its hidden __ordinal component is the argument's value,
+// converted to that component's declared type.
+MaybeExpr ExpressionAnalyzer::AnalyzeEnumerationConstructor(
+    parser::CharBlock typeName, const semantics::DerivedTypeSpec &spec,
+    const std::list<parser::ComponentSpec> &components) {
+  const semantics::Symbol &typeSymbol{spec.typeSymbol()};
+  const auto &typeDetails{typeSymbol.get<semantics::DerivedTypeDetails>()};
+  int enumeratorCount{typeDetails.enumeratorCount()};
+  // Validate: exactly one positional argument, no keywords
+  if (components.size() != 1) {
+    Say(typeName,
+        "Enumeration constructor for '%s' requires exactly one argument"_err_en_US,
+        typeName);
+    return std::nullopt;
+  }
+  const auto &component{components.front()};
+  if (std::get<std::optional<parser::Keyword>>(component.t)) {
+    Say(typeName,
+        "Enumeration constructor for '%s' may not have a keyword argument"_err_en_US,
+        typeName);
+    return std::nullopt;
+  }
+  // Analyze the argument; it must be a scalar INTEGER expression.
+  const parser::Expr &argExpr{
+      std::get<parser::ComponentDataSource>(component.t).v.value()};
+  auto restorer{GetContextualMessages().SetLocation(argExpr.source)};
+  MaybeExpr analyzed{Analyze(argExpr)};
+  if (!analyzed) {
+    return std::nullopt;
+  }
+  auto folded{Fold(std::move(*analyzed))};
+  auto argType{folded.GetType()};
+  if (!argType || argType->category() != TypeCategory::Integer) {
+    Say(argExpr.source,
+        "Enumeration constructor argument must be INTEGER, but is %s"_err_en_US,
+        argType ? argType->AsFortran() : std::string{"typeless"});
+    return std::nullopt;
+  }
+  // R771 requires a scalar. Without this check, an array argument would be
+  // stored in the scalar __ordinal component.
+  if (int rank{folded.Rank()}; rank != 0) {
+    Say(argExpr.source,
+        "Enumeration constructor argument must be scalar, but has rank %d"_err_en_US,
+        rank);
+    return std::nullopt;
+  }
+  // If the value is known at compile time, validate the range
+  if (auto value{ToInt64(folded)}) {
+    if (*value < 1 || *value > enumeratorCount) {
+      Say(argExpr.source,
+          "Enumeration constructor value (%jd) for '%s' must be positive and less than or equal to the number of enumerators (%d)"_err_en_US,
+          static_cast<std::intmax_t>(*value), typeName, enumeratorCount);
+      return std::nullopt;
+    }
+  }
+  // Produce an Expr<SomeDerived> with the ordinal in the __ordinal component
+  StructureConstructor result{spec};
+  if (const auto *scope{spec.GetScope()}) {
+    auto ordinalIter{scope->find(semantics::SourceName{"__ordinal", 9})};
+    if (ordinalIter != scope->end()) {
+      const semantics::Symbol &ordinalSym{*ordinalIter->second};
+      // Store the ordinal with the component's declared type, so that the
+      // ordinal's kind does not depend on the argument's kind: color(2_8)
+      // and color(2) then carry the same ordinal representation.
+      if (auto ordinalType{DynamicType::From(ordinalSym)}) {
+        if (auto converted{
+                ConvertToType(*ordinalType, Expr<SomeType>{folded})}) {
+          folded = Fold(std::move(*converted));
+        }
+      }
+      result.Add(ordinalSym, std::move(folded));
+    }
+  }
+  return AsMaybeExpr(Expr<SomeDerived>{std::move(result)});
+}
+
MaybeExpr ExpressionAnalyzer::Analyze(
const parser::StructureConstructor &structure) {
const auto &parsedType{std::get<parser::DerivedTypeSpec>(structure.t)};
@@ -2478,6 +2543,11 @@ MaybeExpr ExpressionAnalyzer::Analyze(
if (!parsedType.derivedTypeSpec) {
return std::nullopt;
}
+ // F2023 R771: Enumeration constructor — enum_name(scalar-int-expr)
+ if (parsedType.derivedTypeSpec->IsEnumerationType()) {
+ return AnalyzeEnumerationConstructor(typeName, *parsedType.derivedTypeSpec,
+ std::get<std::list<parser::ComponentSpec>>(structure.t));
+ }
auto restorer{AllowNullPointer()}; // NULL() can be a valid component
std::list<ComponentSpec> componentSpecs;
for (const auto &component :
@@ -3569,6 +3639,14 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::FunctionReference &funcRef,
if (!CheckIsValidForwardReference(dtSpec)) {
return std::nullopt;
}
+ // Detect enumeration types and set the category accordingly
+ if (const auto *dtDetails{
+ symbol.detailsIf<semantics::DerivedTypeDetails>()}) {
+ if (dtDetails->isEnumerationType()) {
+ dtSpec.set_category(
+ semantics::DerivedTypeSpec::Category::EnumerationType);
+ }
+ }
const semantics::DeclTypeSpec &type{
semantics::FindOrInstantiateDerivedType(scope, std::move(dtSpec))};
auto &mutableRef{const_cast<parser::FunctionReference &>(funcRef)};
diff --git a/flang/lib/Semantics/resolve-labels.cpp b/flang/lib/Semantics/resolve-labels.cpp
index 2da42b2f26cb1..3449b3d34f903 100644
--- a/flang/lib/Semantics/resolve-labels.cpp
+++ b/flang/lib/Semantics/resolve-labels.cpp
@@ -559,6 +559,14 @@ class ParseTreeAnalyzer {
PopDisposableMap();
}
+ // F2023 C7115: check the optional name on END ENUMERATION TYPE
+ void Post(const parser::EnumerationTypeDef &enumTypeDef) {
+ CheckOptionalName<parser::EnumerationTypeStmt>(
+ "enumeration type definition", enumTypeDef,
+ std::get<parser::Statement<parser::EndEnumerationTypeStmt>>(
+ enumTypeDef.t));
+ }
+
void Post(const parser::LabelDoStmt &labelDoStmt) {
AddLabelReferenceFromDoStmt(std::get<parser::Label>(labelDoStmt.t));
}
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 2a2073f29a26e..2db60f7e84d9f 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1001,6 +1001,9 @@ class DeclarationVisitor : public ArraySpecVisitor,
void Post(const parser::EnumDef &);
bool Pre(const parser::Enumerator &);
bool Pre(const parser::EnumerationTypeDef &);
+ void Post(const parser::EnumerationTypeStmt &);
+ bool Pre(const parser::EnumerationEnumeratorStmt &);
+ void Post(const parser::EndEnumerationTypeStmt &);
bool Pre(const parser::AccessSpec &);
bool Pre(const parser::AsynchronousStmt &);
bool Pre(const parser::ContiguousStmt &);
@@ -6011,14 +6014,93 @@ bool DeclarationVisitor::Pre(const parser::Enumerator &enumerator) {
return false;
}
+// Reset the enumerator state used by ENUM, BIND(C) definitions once an
+// enum-def is complete.
+void DeclarationVisitor::Post(const parser::EnumDef &) {
+ enumerationState_ = EnumeratorState{};
+}
+
+// F2023 R766 EnumerationTypeDef — scope is pushed in Post(EnumerationTypeStmt)
+// and popped in Post(EndEnumerationTypeStmt).
bool DeclarationVisitor::Pre(const parser::EnumerationTypeDef &x) {
+ BeginAttrs();
+ // TODO: Remove this and set true when ENUMERATION TYPEs are implemented.
Say(std::get<parser::Statement<parser::EnumerationTypeStmt>>(x.t).source,
"F2023 ENUMERATION TYPEs are not yet implemented"_err_en_US);
return false;
}
-void DeclarationVisitor::Post(const parser::EnumDef &) {
- enumerationState_ = EnumeratorState{};
+// F2023 R767 EnumerationTypeStmt — create the enumeration type symbol
+// in the enclosing scope and push a DerivedType scope for it.
+// The scope is popped again by Post(EndEnumerationTypeStmt).
+void DeclarationVisitor::Post(const parser::EnumerationTypeStmt &x) {
+ const auto &name{std::get<parser::Name>(x.t)};
+ // Attributes accumulated since BeginAttrs() in Pre(EnumerationTypeDef),
+ // presumably including PUBLIC/PRIVATE from the access-spec.
+ Attrs attrs{EndAttrs()};
+ if (const auto &optAccessSpec{
+ std::get<std::optional<parser::AccessSpec>>(x.t)};
+ optAccessSpec) {
+ // NOTE(review): the access-spec is also diagnosed elsewhere. The tests
+ // expect both "PRIVATE attribute may only appear..." and this message
+ // for a single access-spec outside a module; consider keeping only one.
+ if (!NonDerivedTypeScope().IsModule()) { // F2023 C7114
+ Say(currStmtSource().value(),
+ "Access specifier on ENUMERATION TYPE may only appear in the specification part of a module"_err_en_US);
+ }
+ }
+ DerivedTypeDetails details;
+ details.set_isEnumerationType(true);
+ auto &symbol{MakeSymbol(name, attrs, std::move(details))};
+ symbol.ReplaceName(name.source);
+ PushScope(Scope::Kind::DerivedType, &symbol);
+ // Add a hidden __ordinal component to hold the 1-based enumerator position.
+ // This is a compiler-created INTEGER(4) component that preserves ordinal
+ // identity through constant folding and enables enumerator comparison.
+ // Fortran names must begin with a letter, so "__ordinal" cannot collide
+ // with a user-declared name.
+ // NOTE(review): other code finds this component with a hard-coded
+ // SourceName{"__ordinal", 9}; a shared named constant would keep the
+ // spellings in sync.
+ SourceName ordinalName{context().SaveTempName(std::string{"__ordinal"})};
+ Symbol &ordinalSym{MakeSymbol(currScope(), ordinalName, Attrs{})};
+ ordinalSym.set_details(ObjectEntityDetails{});
+ ordinalSym.SetType(
+ currScope().MakeNumericType(TypeCategory::Integer, KindExpr{4}));
+ ordinalSym.set(Symbol::Flag::CompilerCreated);
+ symbol.get<DerivedTypeDetails>().add_component(ordinalSym);
+}
+
+// F2023 R768 EnumerationEnumeratorStmt: for each enumerator name, create a
+// PARAMETER in the enclosing (non-derived-type) scope. Each one is
+// initialized to an enumeration value whose __ordinal is its 1-based
+// position in the type.
+bool DeclarationVisitor::Pre(const parser::EnumerationEnumeratorStmt &x) {
+  Scope &enclosingScope{NonDerivedTypeScope()};
+  // The current DerivedType scope's symbol is the enumeration type.
+  Symbol *typeSymbol{currScope().symbol()};
+  CHECK(typeSymbol);
+  auto &typeDetails{typeSymbol->get<DerivedTypeDetails>()};
+  // The hidden __ordinal component is created with the type by
+  // Post(EnumerationTypeStmt). It is the same for every enumerator, so
+  // look it up once, outside the loop.
+  auto ordinalIter{currScope().find(SourceName{"__ordinal", 9})};
+  CHECK(ordinalIter != currScope().end());
+  const Symbol &ordinalSym{*ordinalIter->second};
+  // Build a DerivedTypeSpec for the enumeration type.
+  DerivedTypeSpec enumTypeSpec{typeSymbol->name(), *typeSymbol};
+  enumTypeSpec.set_category(DerivedTypeSpec::Category::EnumerationType);
+  DeclTypeSpec &declType{enclosingScope.MakeDerivedType(
+      DeclTypeSpec::TypeDerived, std::move(enumTypeSpec))};
+  for (const parser::Name &name : x.v) {
+    // Every listed name occupies a position, including erroneous ones, so
+    // later ordinals still match their position in the source.
+    int ordinal{typeDetails.enumeratorCount() + 1};
+    typeDetails.set_enumeratorCount(ordinal);
+    // An enumerator name must not already be declared in the enclosing
+    // scope. A symbol with only UnknownDetails (presumably from an
+    // accessibility statement) is reused. Any other prior declaration
+    // (a duplicate enumerator, a variable, a USE-associated name, ...) is
+    // diagnosed here. Otherwise set_details() below would be asked to
+    // replace incompatible details, which it does not permit.
+    if (Symbol *prior{FindInScope(enclosingScope, name)};
+        prior && !prior->has<UnknownDetails>()) {
+      SayAlreadyDeclared(name, *prior);
+      continue;
+    }
+    // Create the enumerator symbol in the enclosing scope, not the
+    // enumeration type's own DerivedType scope.
+    Symbol &enumerator{
+        MakeSymbol(enclosingScope, name.source, Attrs{Attr::PARAMETER})};
+    Resolve(name, enumerator);
+    enumerator.set_details(ObjectEntityDetails{});
+    enumerator.SetType(declType);
+    // Store the init as a StructureConstructor of the enumeration type with
+    // the ordinal in the hidden __ordinal component. This gives each
+    // enumerator a distinct Constant<SomeDerived> value.
+    evaluate::StructureConstructor enumCtor{declType.derivedTypeSpec()};
+    enumCtor.Add(ordinalSym,
+        evaluate::AsGenericExpr(evaluate::Expr<evaluate::CInteger>{ordinal}));
+    enumerator.get<ObjectEntityDetails>().set_init(
+        SomeExpr{evaluate::Expr<evaluate::SomeDerived>{
+            evaluate::Constant<evaluate::SomeDerived>{std::move(enumCtor)}}});
+  }
+  return false;
+}
+
+// F2023 R769 EndEnumerationTypeStmt: pop the scope that
+// Post(EnumerationTypeStmt) pushed.
+void DeclarationVisitor::Post(const parser::EndEnumerationTypeStmt &) {
+  PopScope();
}
bool DeclarationVisitor::Pre(const parser::AccessSpec &x) {
@@ -6610,6 +6692,17 @@ void DeclarationVisitor::Post(const parser::DerivedTypeSpec &x) {
// in the current scope, this spec will be moved into that collection.
const auto &dtDetails{spec->typeSymbol().get<DerivedTypeDetails>()};
auto category{GetDeclTypeSpecCategory()};
+
+ // Enumeration types are a special case of derived types and are handled
+ // differently.
+ if (dtDetails.isEnumerationType()) {
+ spec->set_category(DerivedTypeSpec::Category::EnumerationType);
+ DeclTypeSpec &type{currScope().MakeDerivedType(category, std::move(*spec))};
+ SetDeclTypeSpec(type);
+ x.derivedTypeSpec = &GetDeclTypeSpec()->derivedTypeSpec();
+ return;
+ }
+
if (dtDetails.isForwardReferenced()) {
DeclTypeSpec &type{currScope().MakeDerivedType(category, std::move(*spec))};
SetDeclTypeSpec(type);
@@ -8930,6 +9023,12 @@ class ExecutionPartSkimmerBase {
return true;
}
void Post(const parser::DerivedTypeDef &) { PopScope(); }
+ // Hide the enumeration type's name and push a scope for the definition;
+ // the matching PopScope() is in Post(EnumerationTypeDef) below.
+ bool Pre(const parser::EnumerationTypeStmt &x) {
+ Hide(std::get<parser::Name>(x.t));
+ PushScope();
+ return true;
+ }
+ void Post(const parser::EnumerationTypeDef &) { PopScope(); }
bool Pre(const parser::SelectTypeConstruct &) {
PushScope();
return true;
@@ -9404,6 +9503,12 @@ const parser::Name *DeclarationVisitor::FindComponent(
return &component;
}
} else if (DerivedTypeSpec * derived{type->AsDerived()}) {
+ if (derived->IsEnumerationType()) {
+ Say(component.source,
+ "Component reference is not allowed for enumeration type '%s'"_err_en_US,
+ derived->typeSymbol().name());
+ return nullptr;
+ }
derived->Instantiate(currScope()); // in case of forward referenced type
if (const Scope * scope{derived->scope()}) {
if (Resolve(component, scope->FindComponent(component.source))) {
@@ -11099,6 +11204,25 @@ class DeferredCheckVisitor {
}
}
+ // While the statements of an enumeration type definition are visited,
+ // make the type's own DerivedType scope current. The previous scope is
+ // saved in outerScope_ and restored at END ENUMERATION TYPE.
+ void Post(const parser::EnumerationTypeStmt &x) {
+ const auto &name{std::get<parser::Name>(x.t)};
+ if (Symbol * symbol{name.symbol}) {
+ if (Scope * scope{symbol->scope()}) {
+ if (scope->IsDerivedType()) {
+ CHECK(outerScope_ == nullptr);
+ outerScope_ = &resolver_.currScope();
+ resolver_.SetScope(*scope);
+ }
+ }
+ }
+ }
+ // Restore the scope saved by Post(EnumerationTypeStmt), if one was saved.
+ void Post(const parser::EndEnumerationTypeStmt &) {
+ if (outerScope_) {
+ resolver_.SetScope(*outerScope_);
+ outerScope_ = nullptr;
+ }
+ }
+
void Post(const parser::ProcInterface &pi) {
if (const auto *name{std::get_if<parser::Name>(&pi.u)}) {
resolver_.CheckExplicitInterface(*name);
diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp
index 4e1c9bae9c153..fd323cbb0177c 100644
--- a/flang/lib/Semantics/rewrite-parse-tree.cpp
+++ b/flang/lib/Semantics/rewrite-parse-tree.cpp
@@ -81,6 +81,7 @@ class RewriteMutator {
bool Pre(parser::EndSubmoduleStmt &) { return false; }
bool Pre(parser::EndSubroutineStmt &) { return false; }
bool Pre(parser::EndTypeStmt &) { return false; }
+ // Like the other End*Stmt handlers: do not descend into the statement.
+ bool Pre(parser::EndEnumerationTypeStmt &) { return false; }
bool Pre(parser::OmpBlockConstruct &);
bool Pre(parser::OpenMPLoopConstruct &);
diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp
index 79511c93b79b4..c965dc0d1c32d 100644
--- a/flang/lib/Semantics/tools.cpp
+++ b/flang/lib/Semantics/tools.cpp
@@ -182,9 +182,18 @@ bool IsIntrinsicRelational(common::RelationalOperator opr,
return opr == common::RelationalOperator::EQ ||
opr == common::RelationalOperator::NE ||
(cat0 != TypeCategory::Complex && cat1 != TypeCategory::Complex);
+ } else if (cat0 == TypeCategory::Character &&
+ cat1 == TypeCategory::Character) {
+ return true;
+ } else if (cat0 == TypeCategory::Derived && cat1 == TypeCategory::Derived) {
+ // Same enumeration type: all six relational operators are allowed
+ const auto *derived0{evaluate::GetDerivedTypeSpec(type0)};
+ const auto *derived1{evaluate::GetDerivedTypeSpec(type1)};
+ return derived0 && derived1 && derived0->IsEnumerationType() &&
+ derived1->IsEnumerationType() &&
+ &derived0->typeSymbol() == &derived1->typeSymbol();
} else {
- // not both numeric: only Character is ok
- return cat0 == TypeCategory::Character && cat1 == TypeCategory::Character;
+ return false;
}
}
}
diff --git a/flang/test/Semantics/case01.f90 b/flang/test/Semantics/case01.f90
index c9631d299e49c..7caa453ef6252 100644
--- a/flang/test/Semantics/case01.f90
+++ b/flang/test/Semantics/case01.f90
@@ -45,7 +45,7 @@ program selectCaseProg
end select
! C1145
- !ERROR: SELECT CASE expression must be integer, logical, or character
+ !ERROR: SELECT CASE expression must be integer, logical, character, or enumeration type
select case (grade4)
case (1.0)
case (2.0)
@@ -53,7 +53,7 @@ program selectCaseProg
case default
end select
- !ERROR: SELECT CASE expression must be integer, logical, or character
+ !ERROR: SELECT CASE expression must be integer, logical, character, or enumeration type
select case (score)
case (score_val)
case (scores(100))
>From d704a6ef5f9001e226709e2b5535dfbe08e768d8 Mon Sep 17 00:00:00 2001
From: Kevin Wyatt <kwyatt at hpe.com>
Date: Mon, 20 Apr 2026 11:13:28 -0500
Subject: [PATCH 2/2] Add tests and the INT() folding needed for enumeration
 relational operators.
---
flang/lib/Evaluate/fold-integer.cpp | 28 +++++
flang/lib/Semantics/resolve-names.cpp | 5 +-
.../enumeration-type-declarations.f90 | 84 +++++++++++++
.../Semantics/enumeration-type-relational.f90 | 117 ++++++++++++++++++
4 files changed, 230 insertions(+), 4 deletions(-)
create mode 100644 flang/test/Semantics/enumeration-type-declarations.f90
create mode 100644 flang/test/Semantics/enumeration-type-relational.f90
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 9f2bb94a9213f..d5dcf272d53d7 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -761,6 +761,34 @@ std::optional<Expr<T>> FoldIntrinsicFunctionCommon(
} else if (name == "int" || name == "int2" || name == "int8" ||
name == "uint") {
if (auto *expr{UnwrapExpr<Expr<SomeType>>(args[0])}) {
+ // Check for enumeration type argument first — extract __ordinal
+ if (auto *derivedExpr{std::get_if<Expr<SomeDerived>>(&expr->u)}) {
+ if (auto type{derivedExpr->GetType()}) {
+ if (const auto *derived{GetDerivedTypeSpec(*type)}) {
+ if (derived->IsEnumerationType()) {
+ if (const auto *scope{derived->GetScope()}) {
+ auto ordIter{
+ scope->find(semantics::SourceName{"__ordinal", 9})};
+ if (ordIter != scope->end()) {
+ const semantics::Symbol &ordSym{*ordIter->second};
+ if (auto *constant{
+ UnwrapConstantValue<SomeDerived>(*derivedExpr)}) {
+ if (auto sc{constant->GetScalarValue()}) {
+ if (auto ordExpr{sc->Find(ordSym)}) {
+ if (auto ordVal{ToInt64(*ordExpr)}) {
+ return Expr<T>{Constant<T>{Scalar<T>{*ordVal}}};
+ }
+ }
+ }
+ }
+ }
+ }
+ // Non-constant enumeration argument — leave unfolded
+ return Expr<T>{std::move(funcRef)};
+ }
+ }
+ }
+ }
return common::visit(
[&](auto &&x) -> Expr<T> {
using From = std::decay_t<decltype(x)>;
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 2db60f7e84d9f..c1e57e9dedc32 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -6022,10 +6022,7 @@ void DeclarationVisitor::Post(const parser::EnumDef &) {
// and popped in Post(EndEnumerationTypeStmt).
bool DeclarationVisitor::Pre(const parser::EnumerationTypeDef &x) {
BeginAttrs();
- // TODO: Remove this and set true when ENUMERATION TYPEs are implemented.
- Say(std::get<parser::Statement<parser::EnumerationTypeStmt>>(x.t).source,
- "F2023 ENUMERATION TYPEs are not yet implemented"_err_en_US);
- return false;
+ return true; // visit the statements; their handlers build the type
}
// F2023 R767 EnumerationTypeStmt — create the enumeration type symbol
diff --git a/flang/test/Semantics/enumeration-type-declarations.f90 b/flang/test/Semantics/enumeration-type-declarations.f90
new file mode 100644
index 0000000000000..de66ae888268e
--- /dev/null
+++ b/flang/test/Semantics/enumeration-type-declarations.f90
@@ -0,0 +1,84 @@
+! RUN: %python %S/test_errors.py %s %flang_fc1
+! Test declaration, constructor, and expression semantics for enumeration types
+
+! C7114: access specifier only allowed in module
+subroutine test_access_specifier_outside_module()
+ !ERROR: PRIVATE attribute may only appear in the specification part of a module
+ !ERROR: Access specifier on ENUMERATION TYPE may only appear in the specification part of a module
+ enumeration type, private :: color
+ enumerator :: red, green, blue
+ end enumeration type
+end subroutine
+
+! Valid: basic declarations and usage
+subroutine test_basic_declarations()
+ enumeration type :: color
+ enumerator :: red, green, blue
+ end enumeration type
+
+ type(color) :: c1, c2
+ logical :: l
+
+ ! Valid: assign an enumerator
+ c1 = red
+ c2 = blue
+
+ ! Valid: comparison produces logical
+ l = (c1 == c2)
+ l = (c1 /= red)
+end subroutine
+
+! Valid: constructor syntax — color(n) where n is a positive integer <= count
+subroutine test_constructor_valid()
+ enumeration type :: color
+ enumerator :: red, green, blue
+ end enumeration type
+
+ type(color) :: c
+
+ ! Valid: integer constructor in range
+ c = color(1)
+ c = color(2)
+ c = color(3)
+end subroutine
+
+! Constructor errors
+subroutine test_constructor_errors()
+ enumeration type :: color
+ enumerator :: red, green, blue
+ end enumeration type
+
+ type(color) :: c
+
+ !ERROR: Enumeration constructor for 'color' requires exactly one argument
+ c = color()
+
+ !ERROR: Enumeration constructor for 'color' requires exactly one argument
+ c = color(1, 2)
+
+ !ERROR: Enumeration constructor for 'color' may not have a keyword argument
+ c = color(val=1)
+
+ !ERROR: Enumeration constructor argument must be INTEGER, but is REAL(4)
+ c = color(1.0)
+
+ !ERROR: Enumeration constructor value (0) for 'color' must be positive and less than or equal to the number of enumerators (3)
+ c = color(0)
+
+ !ERROR: Enumeration constructor value (4) for 'color' must be positive and less than or equal to the number of enumerators (3)
+ c = color(4)
+end subroutine
+
+! Component reference on enumeration type is not allowed
+subroutine test_component_reference()
+ enumeration type :: color
+ enumerator :: red, green, blue
+ end enumeration type
+
+ type(color) :: c
+ integer :: i
+
+ c = red
+ ! The hidden __ordinal component cannot be spelled in source (Fortran names
+ ! must begin with a letter), so use an ordinary component name here.
+ !ERROR: Component reference is not allowed for enumeration type 'color'
+ i = c%ordinal
+end subroutine
diff --git a/flang/test/Semantics/enumeration-type-relational.f90 b/flang/test/Semantics/enumeration-type-relational.f90
new file mode 100644
index 0000000000000..507635c6bbdd1
--- /dev/null
+++ b/flang/test/Semantics/enumeration-type-relational.f90
@@ -0,0 +1,117 @@
+! RUN: %python %S/test_errors.py %s %flang_fc1
+! Test relational operators and SELECT CASE for enumeration types (F2023 7.6.2)
+
+module enum_mod
+ enumeration type :: color
+ enumerator :: red, green, blue
+ end enumeration type
+
+ enumeration type :: direction
+ enumerator :: north, south, east, west
+ end enumeration type
+
+ enumeration type :: w_value
+ enumerator :: w1, w2, w3, w4, w5
+ end enumeration type
+end module
+
+subroutine test_relational_same_type()
+ use enum_mod
+ logical :: result
+
+ ! Valid: all six relational operators between same-type enumerators
+ result = red == red
+ result = red /= green
+ result = red < green
+ result = green > red
+ result = red <= red
+ result = blue >= green
+end subroutine
+
+subroutine test_relational_cross_type()
+ use enum_mod
+
+ !ERROR: Operands of .EQ. must have comparable types; have TYPE(color) and TYPE(direction)
+ if (red == north) stop 1
+
+ !ERROR: Operands of .LT. must have comparable types; have TYPE(color) and TYPE(direction)
+ if (red < north) stop 2
+end subroutine
+
+subroutine test_relational_enum_vs_integer()
+ use enum_mod
+
+ !ERROR: Operands of .EQ. must have comparable types; have TYPE(color) and INTEGER(4)
+ if (red == 1) stop 1
+
+ !ERROR: Operands of .EQ. must have comparable types; have INTEGER(4) and TYPE(color)
+ if (1 == red) stop 2
+end subroutine
+
+subroutine test_select_case_basic(w)
+ use enum_mod
+ type(w_value), intent(in) :: w
+
+ ! Valid: SELECT CASE with enumerator names as case values
+ select case (w)
+ case (w1)
+ print *, 'w1'
+ case (w2)
+ print *, 'w2'
+ case default
+ print *, 'other'
+ end select
+end subroutine
+
+subroutine test_select_case_range(w)
+ use enum_mod
+ type(w_value), intent(in) :: w
+
+ ! Valid: SELECT CASE with ranges
+ select case (w)
+ case (w1)
+ print *, 'w1'
+ case (w2:w4)
+ print *, 'w2 to w4'
+ case (w5)
+ print *, 'w5'
+ end select
+end subroutine
+
+subroutine test_select_case_wrong_enum(w)
+ use enum_mod
+ type(w_value), intent(in) :: w
+
+ select case (w)
+ !ERROR: CASE value has type 'color' which is not compatible with the SELECT CASE expression's type 'ENUMERATION TYPE :: w_value'
+ case (red)
+ print *, 'wrong'
+ case default
+ print *, 'ok'
+ end select
+end subroutine
+
+subroutine test_select_case_integer_case(w)
+ use enum_mod
+ type(w_value), intent(in) :: w
+
+ select case (w)
+ !ERROR: CASE value has type 'INTEGER(4)' which is not compatible with the SELECT CASE expression's type 'ENUMERATION TYPE :: w_value'
+ case (1)
+ print *, 'wrong'
+ case default
+ print *, 'ok'
+ end select
+end subroutine
+
+subroutine test_select_case_non_enum_derived()
+ type :: my_type
+ integer :: val
+ end type
+ type(my_type) :: x = my_type(1)
+
+ !ERROR: SELECT CASE expression must be integer, logical, character, or enumeration type
+ select case (x)
+ case default
+ end select
+end subroutine
More information about the flang-commits
mailing list