[flang-commits] [flang] [flang] Enumeration Type: (PR 2/5) Name Resolution + Expression + Relational + SELECT CASE (PR #193028)

via flang-commits flang-commits at lists.llvm.org
Tue Jun 23 14:58:25 PDT 2026


https://github.com/kwyatt-ext updated https://github.com/llvm/llvm-project/pull/193028

>From 71c3222d0c1ccb893aabfd7cbf74c813cd452f41 Mon Sep 17 00:00:00 2001
From: Kevin Wyatt <kwyatt at hpe.com>
Date: Thu, 16 Apr 2026 12:49:27 -0500
Subject: [PATCH 1/4] Enumeration Type Sem-2: Name Resolution + Expression +
 Relational + SELECT CASE (PRs 3-5)

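Adds name resolution for ENUMERATION TYPE, expression analysis for
enumeration constructors (whose value is kept in a hidden __ordinal
component), relational operator support, and SELECT CASE support.
In this patch Pre(EnumerationTypeDef) still reports the 'not yet
implemented' error and returns false, so the new declaration handlers
are not yet reached.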

Includes bug fixes from PR 9: use GetScope() rather than scope() in
expression.cpp, and, in evaluate/tools.cpp, have Relate() wrap
non-constant enumeration operands in INT() so they compare as integers.

Files from original PRs 3-5 plus targeted fixes from PR 9.
---
 flang/include/flang/Semantics/expression.h |   3 +
 flang/lib/Evaluate/formatting.cpp          |   8 ++
 flang/lib/Evaluate/tools.cpp               |  73 ++++++++++++
 flang/lib/Semantics/check-case.cpp         | 115 +++++++++++++++++-
 flang/lib/Semantics/check-declarations.cpp |   4 +
 flang/lib/Semantics/expression.cpp         |  78 +++++++++++++
 flang/lib/Semantics/resolve-labels.cpp     |   8 ++
 flang/lib/Semantics/resolve-names.cpp      | 128 ++++++++++++++++++++-
 flang/lib/Semantics/rewrite-parse-tree.cpp |   1 +
 flang/lib/Semantics/tools.cpp              |  13 ++-
 flang/test/Semantics/case01.f90            |   4 +-
 11 files changed, 427 insertions(+), 8 deletions(-)

diff --git a/flang/include/flang/Semantics/expression.h b/flang/include/flang/Semantics/expression.h
index f93b9a892715a..404499ada1b14 100644
--- a/flang/include/flang/Semantics/expression.h
+++ b/flang/include/flang/Semantics/expression.h
@@ -414,6 +414,9 @@ class ExpressionAnalyzer {
   };
   MaybeExpr CheckStructureConstructor(parser::CharBlock typeName,
       const semantics::DerivedTypeSpec &, std::list<ComponentSpec> &&);
+  MaybeExpr AnalyzeEnumerationConstructor(parser::CharBlock typeName,
+      const semantics::DerivedTypeSpec &,
+      const std::list<parser::ComponentSpec> &); // F2023 R771
 
   MaybeExpr IterativelyAnalyzeSubexpressions(const parser::Expr &);
 
diff --git a/flang/lib/Evaluate/formatting.cpp b/flang/lib/Evaluate/formatting.cpp
index 3604484254196..10c834872b863 100644
--- a/flang/lib/Evaluate/formatting.cpp
+++ b/flang/lib/Evaluate/formatting.cpp
@@ -664,6 +664,14 @@ static std::string DerivedTypeSpecAsFortran(
 
 llvm::raw_ostream &StructureConstructor::AsFortran(llvm::raw_ostream &o) const {
   o << DerivedTypeSpecAsFortran(result_.derivedTypeSpec());
+  if (result_.derivedTypeSpec().IsEnumerationType()) {
+    // Print as enum_name(ordinal) without exposing the hidden __ordinal keyword
+    o << '(';
+    if (!values_.empty()) {
+      values_.begin()->second.value().AsFortran(o); // the __ordinal value
+    }
+    return o << ')';
+  }
   if (values_.empty()) {
     o << '(';
   } else {
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 82dcd1e795f49..9f6cf6f97501a 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -755,6 +755,79 @@ std::optional<Expr<LogicalResult>> Relate(parser::ContextualMessages &messages,
                 },
                 std::move(cx.u), std::move(cy.u));
           },
+          [&](Expr<SomeDerived> &&dx,
+              Expr<SomeDerived> &&dy) -> std::optional<Expr<LogicalResult>> {
+            // Operands of the same enumeration type compare by their __ordinal
+            // values; other derived-type operands are an internal error.
+            auto xType{dx.GetType()};
+            auto yType{dy.GetType()};
+            if (xType && yType) {
+              const auto *xDerived{GetDerivedTypeSpec(*xType)};
+              const auto *yDerived{GetDerivedTypeSpec(*yType)};
+              if (xDerived && yDerived && xDerived->IsEnumerationType() &&
+                  yDerived->IsEnumerationType() &&
+                  &xDerived->typeSymbol() == &yDerived->typeSymbol()) {
+                if (const auto *scope{xDerived->GetScope()}) {
+                  auto ordIter{
+                      scope->find(semantics::SourceName{"__ordinal", 9})};
+                  if (ordIter != scope->end()) {
+                    const semantics::Symbol &ordSym{*ordIter->second};
+                    // Ordinal of a constant or StructureConstructor operand
+                    auto extractOrdinal = [&](Expr<SomeDerived> &expr)
+                        -> std::optional<Expr<SomeType>> {
+                      if (auto *constant{
+                              UnwrapConstantValue<SomeDerived>(expr)}) {
+                        if (auto sc{constant->GetScalarValue()}) {
+                          return sc->Find(ordSym);
+                        }
+                      } else if (auto *sc{
+                                     UnwrapExpr<StructureConstructor>(expr)}) {
+                        return sc->Find(ordSym);
+                      }
+                      return std::nullopt;
+                    };
+                    auto xOrd{extractOrdinal(dx)};
+                    auto yOrd{extractOrdinal(dy)};
+                    if (xOrd && yOrd) {
+                      return Relate(
+                          messages, opr, std::move(*xOrd), std::move(*yOrd));
+                    }
+                    // Otherwise (an ordinal could not be extracted), compare
+                    // INT(x) with INT(y), each a FunctionRef<INTEGER(4)>
+                    // calling the elemental INT intrinsic on the operand.
+                    auto makeIntCall =
+                        [&](Expr<SomeDerived> &&operand) -> Expr<SomeType> {
+                      using IntType = Type<TypeCategory::Integer, 4>;
+                      DynamicType enumType{*xDerived};
+                      DynamicType intResultType{TypeCategory::Integer, 4};
+                      characteristics::DummyDataObject ddo{
+                          characteristics::TypeAndShape{enumType}};
+                      ddo.intent = common::Intent::In;
+                      characteristics::Procedure::Attrs attrs;
+                      attrs.set(characteristics::Procedure::Attr::Pure);
+                      attrs.set(characteristics::Procedure::Attr::Elemental);
+                      characteristics::DummyArguments dummies;
+                      dummies.emplace_back("a"s, std::move(ddo));
+                      SpecificIntrinsic intSpec{"int"s,
+                          characteristics::Procedure{
+                              characteristics::FunctionResult{intResultType},
+                              std::move(dummies), attrs}};
+                      ActualArguments intArgs;
+                      intArgs.emplace_back(AsGenericExpr(std::move(operand)));
+                      return AsGenericExpr(
+                          Expr<SomeInteger>(Expr<IntType>(FunctionRef<IntType>{
+                              ProcedureDesignator{std::move(intSpec)},
+                              std::move(intArgs)})));
+                    };
+                    return Relate(messages, opr, makeIntCall(std::move(dx)),
+                        makeIntCall(std::move(dy)));
+                  }
+                }
+              }
+            }
+            DIE("invalid types for relational operator");
+            return std::optional<Expr<LogicalResult>>{};
+          },
           // Default case
           [&](auto &&, auto &&) {
             DIE("invalid types for relational operator");
diff --git a/flang/lib/Semantics/check-case.cpp b/flang/lib/Semantics/check-case.cpp
index 9004d8b3a28f9..fa0c08c317ba1 100644
--- a/flang/lib/Semantics/check-case.cpp
+++ b/flang/lib/Semantics/check-case.cpp
@@ -11,6 +11,7 @@
 #include "flang/Common/reference.h"
 #include "flang/Common/template.h"
 #include "flang/Evaluate/fold.h"
+#include "flang/Evaluate/tools.h"
 #include "flang/Evaluate/type.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/semantics.h"
@@ -236,6 +237,103 @@ template <TypeCategory CAT> struct TypeVisitor {
   const std::list<parser::CaseConstruct::Case> &caseList;
 };
 
+// Type-check one CASE value and rewrite it to its __ordinal; false on failure
+static bool ConvertEnumCaseValue(SemanticsContext &context,
+    const parser::CaseValue &caseValue,
+    const semantics::DerivedTypeSpec &enumType,
+    const semantics::Symbol &ordSym) {
+  const auto &expr{parser::UnwrapRef<parser::Expr>(caseValue)};
+  auto *x{expr.typedExpr.get()};
+  if (!x || !x->v) {
+    return false;
+  }
+  auto type{x->v->GetType()};
+  if (!type || type->category() != TypeCategory::Derived) {
+    std::string typeStr{type ? type->AsFortran() : "typeless"s};
+    context.Say(expr.source,
+        "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
+        typeStr, enumType.AsFortran());
+    return false;
+  }
+  const auto *caseDerived{evaluate::GetDerivedTypeSpec(*type)};
+  if (!caseDerived || !caseDerived->IsEnumerationType() ||
+      &caseDerived->typeSymbol() != &enumType.typeSymbol()) {
+    context.Say(expr.source,
+        "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
+        type->AsFortran(), enumType.AsFortran());
+    return false;
+  }
+  // Fold to a scalar constant and replace the CASE value with its __ordinal
+  parser::Messages buffer;
+  parser::ContextualMessages foldingMessages{expr.source, &buffer};
+  evaluate::FoldingContext foldingContext{
+      context.foldingContext(), foldingMessages};
+  auto folded{evaluate::Fold(foldingContext, SomeExpr{*x->v})};
+  if (auto sc{
+          evaluate::GetScalarConstantValue<evaluate::SomeDerived>(folded)}) {
+    if (auto ordExpr{sc->Find(ordSym)}) {
+      x->v = std::move(*ordExpr);
+      return true;
+    }
+  }
+  context.Say(expr.source,
+      "CASE value (%s) must be a constant scalar"_err_en_US, x->v->AsFortran());
+  return false;
+}
+
+// Apply ConvertEnumCaseValue to every CASE value and range bound; false if
+// any failed or enumType has no __ordinal component.
+static bool ConvertEnumCaseValues(SemanticsContext &context,
+    const std::list<parser::CaseConstruct::Case> &cases,
+    const semantics::DerivedTypeSpec &enumType) {
+  const auto *scope{enumType.GetScope()};
+  if (!scope) {
+    return false;
+  }
+  auto ordIter{scope->find(semantics::SourceName{"__ordinal", 9})};
+  if (ordIter == scope->end()) {
+    return false;
+  }
+  const semantics::Symbol &ordSym{*ordIter->second};
+  bool ok{true};
+  for (const auto &c : cases) {
+    const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
+    const auto &selector{std::get<parser::CaseSelector>(stmt.statement.t)};
+    common::visit(common::visitors{
+                      [&](const std::list<parser::CaseValueRange> &ranges) {
+                        for (const auto &range : ranges) {
+                          common::visit(
+                              common::visitors{
+                                  [&](const parser::CaseValue &val) {
+                                    if (!ConvertEnumCaseValue(
+                                            context, val, enumType, ordSym)) {
+                                      ok = false;
+                                    }
+                                  },
+                                  [&](const parser::CaseValueRange::Range &r) {
+                                    const auto &[lower, upper]{r.t};
+                                    if (lower &&
+                                        !ConvertEnumCaseValue(context, *lower,
+                                            enumType, ordSym)) {
+                                      ok = false;
+                                    }
+                                    if (upper &&
+                                        !ConvertEnumCaseValue(context, *upper,
+                                            enumType, ordSym)) {
+                                      ok = false;
+                                    }
+                                  },
+                              },
+                              range.u);
+                        }
+                      },
+                      [](const parser::Default &) {},
+                  },
+        selector.u);
+  }
+  return ok;
+}
+
 void CaseChecker::Enter(const parser::CaseConstruct &construct) {
   const auto &selectCaseStmt{
       std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)};
@@ -266,13 +364,26 @@ void CaseChecker::Enter(const parser::CaseConstruct &construct) {
       common::SearchTypes(
           TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList});
       return;
+    case TypeCategory::Derived:
+      if (const auto *derived{evaluate::GetDerivedTypeSpec(*exprType)}) {
+        if (derived->IsEnumerationType()) { // CASE values become ordinals
+          if (ConvertEnumCaseValues(context_, caseList, *derived)) {
+            evaluate::DynamicType intType{TypeCategory::Integer, 4};
+            CaseValues<evaluate::Type<TypeCategory::Integer, 4>>{
+                context_, intType}
+                .Check(caseList);
+          }
+          return;
+        }
+      }
+      break; // other derived types: error below
     default:
       break;
     }
   }
   context_.Say(selectExpr.source,
       context_.IsEnabled(common::LanguageFeature::Unsigned)
-          ? "SELECT CASE expression must be integer, unsigned, logical, or character"_err_en_US
-          : "SELECT CASE expression must be integer, logical, or character"_err_en_US);
+          ? "SELECT CASE expression must be integer, unsigned, logical, character, or enumeration type"_err_en_US
+          : "SELECT CASE expression must be integer, logical, character, or enumeration type"_err_en_US);
 }
 } // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index 6a2fd40aeec79..a11b6bc4d691c 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -1841,6 +1841,10 @@ void CheckHelper::CheckExternal(const Symbol &symbol) {
 
 void CheckHelper::CheckDerivedType(
     const Symbol &derivedType, const DerivedTypeDetails &details) {
+  if (details.isEnumerationType()) {
+    // Only the compiler-created __ordinal component; nothing to check.
+    return;
+  }
   if (details.isForwardReferenced() && !context_.HasError(derivedType)) {
     messages_.Say("The derived type '%s' has not been defined"_err_en_US,
         derivedType.name());
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 50869a3c870ef..26bd9fa3b382e 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -1482,6 +1482,12 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::StructureComponent &sc) {
   const auto &name{sc.Component().source};
   if (auto *dtExpr{UnwrapExpr<Expr<SomeDerived>>(*base)}) {
     const auto *dtSpec{GetDerivedTypeSpec(dtExpr->GetType())};
+    if (dtSpec && dtSpec->IsEnumerationType()) { // only a hidden __ordinal
+      Say(name,
+          "Component reference is not allowed for enumeration type '%s'"_err_en_US,
+          dtSpec->typeSymbol().name());
+      return std::nullopt;
+    }
     if (isTypeParamInquiry) {
       if (auto *designator{UnwrapExpr<Designator<SomeDerived>>(*dtExpr)}) {
         if (std::optional<DynamicType> dyType{DynamicType::From(*sym)}) {
@@ -2462,6 +2468,65 @@ MaybeExpr ExpressionAnalyzer::CheckStructureConstructor(
   return AsMaybeExpr(Expr<SomeDerived>{std::move(result)});
 }
 
+// F2023 R771: enumeration-constructor is enumeration-type-spec (
+// scalar-int-expr ); its value must be in [1, number of enumerators].
+MaybeExpr ExpressionAnalyzer::AnalyzeEnumerationConstructor(
+    parser::CharBlock typeName, const semantics::DerivedTypeSpec &spec,
+    const std::list<parser::ComponentSpec> &components) {
+  int enumeratorCount{
+      spec.typeSymbol().get<semantics::DerivedTypeDetails>().enumeratorCount()};
+  if (components.size() != 1) { // exactly one positional argument
+    Say(typeName,
+        "Enumeration constructor for '%s' requires exactly one argument"_err_en_US,
+        typeName);
+    return std::nullopt;
+  }
+  const auto &component{components.front()};
+  if (std::get<std::optional<parser::Keyword>>(component.t)) {
+    Say(typeName,
+        "Enumeration constructor for '%s' may not have a keyword argument"_err_en_US,
+        typeName);
+    return std::nullopt;
+  }
+  const parser::Expr &argExpr{
+      std::get<parser::ComponentDataSource>(component.t).v.value()};
+  auto restorer{GetContextualMessages().SetLocation(argExpr.source)};
+  MaybeExpr analyzed{Analyze(argExpr)}; // must be a scalar INTEGER
+  if (!analyzed) {
+    return std::nullopt;
+  }
+  auto folded{Fold(std::move(*analyzed))};
+  auto argType{folded.GetType()};
+  if (!argType || argType->category() != TypeCategory::Integer) {
+    Say(argExpr.source,
+        "Enumeration constructor argument must be INTEGER, but is %s"_err_en_US,
+        argType ? argType->AsFortran() : std::string{"typeless"});
+    return std::nullopt;
+  } else if (folded.Rank() != 0) { // R771 requires a scalar-int-expr
+    Say(argExpr.source,
+        "Enumeration constructor argument must be scalar"_err_en_US);
+    return std::nullopt;
+  }
+  if (auto value{ToInt64(folded)}) { // known constant: check the range
+    if (*value < 1 || *value > enumeratorCount) {
+      Say(argExpr.source,
+          "Enumeration constructor value (%jd) for '%s' must be positive and less than or equal to the number of enumerators (%d)"_err_en_US,
+          static_cast<std::intmax_t>(*value), typeName, enumeratorCount);
+      return std::nullopt;
+    }
+  }
+  StructureConstructor result{spec}; // the value goes in __ordinal
+  if (const auto *scope{spec.GetScope()}) {
+    auto ordinalIter{scope->find(semantics::SourceName{"__ordinal", 9})};
+    if (ordinalIter != scope->end()) {
+      if (auto ord{ConvertToType(*ordinalIter->second, std::move(folded))}) {
+        result.Add(*ordinalIter->second, Fold(std::move(*ord))); // INTEGER(4)
+      }
+    }
+  }
+  return AsMaybeExpr(Expr<SomeDerived>{std::move(result)});
+}
+
 MaybeExpr ExpressionAnalyzer::Analyze(
     const parser::StructureConstructor &structure) {
   const auto &parsedType{std::get<parser::DerivedTypeSpec>(structure.t)};
@@ -2478,6 +2543,11 @@ MaybeExpr ExpressionAnalyzer::Analyze(
   if (!parsedType.derivedTypeSpec) {
     return std::nullopt;
   }
+  // F2023 R771: type-name(scalar-int-expr) is an enumeration constructor
+  if (parsedType.derivedTypeSpec->IsEnumerationType()) {
+    return AnalyzeEnumerationConstructor(typeName, *parsedType.derivedTypeSpec,
+        std::get<std::list<parser::ComponentSpec>>(structure.t));
+  }
   auto restorer{AllowNullPointer()}; // NULL() can be a valid component
   std::list<ComponentSpec> componentSpecs;
   for (const auto &component :
@@ -3569,6 +3639,14 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::FunctionReference &funcRef,
         if (!CheckIsValidForwardReference(dtSpec)) {
           return std::nullopt;
         }
+        // Tag an enumeration type's spec with the EnumerationType category
+        if (const auto *dtDetails{
+                symbol.detailsIf<semantics::DerivedTypeDetails>()}) {
+          if (dtDetails->isEnumerationType()) {
+            dtSpec.set_category(
+                semantics::DerivedTypeSpec::Category::EnumerationType);
+          }
+        }
         const semantics::DeclTypeSpec &type{
             semantics::FindOrInstantiateDerivedType(scope, std::move(dtSpec))};
         auto &mutableRef{const_cast<parser::FunctionReference &>(funcRef)};
diff --git a/flang/lib/Semantics/resolve-labels.cpp b/flang/lib/Semantics/resolve-labels.cpp
index 2da42b2f26cb1..3449b3d34f903 100644
--- a/flang/lib/Semantics/resolve-labels.cpp
+++ b/flang/lib/Semantics/resolve-labels.cpp
@@ -559,6 +559,14 @@ class ParseTreeAnalyzer {
     PopDisposableMap();
   }
 
+  // F2023 C7115: check the optional name on END ENUMERATION TYPE
+  void Post(const parser::EnumerationTypeDef &enumTypeDef) {
+    CheckOptionalName<parser::EnumerationTypeStmt>(
+        "enumeration type definition", enumTypeDef,
+        std::get<parser::Statement<parser::EndEnumerationTypeStmt>>(
+            enumTypeDef.t));
+  }
+
   void Post(const parser::LabelDoStmt &labelDoStmt) {
     AddLabelReferenceFromDoStmt(std::get<parser::Label>(labelDoStmt.t));
   }
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 2a2073f29a26e..2db60f7e84d9f 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1001,6 +1001,9 @@ class DeclarationVisitor : public ArraySpecVisitor,
   void Post(const parser::EnumDef &);
   bool Pre(const parser::Enumerator &);
   bool Pre(const parser::EnumerationTypeDef &);
+  void Post(const parser::EnumerationTypeStmt &); // push the type's scope
+  bool Pre(const parser::EnumerationEnumeratorStmt &); // declare enumerators
+  void Post(const parser::EndEnumerationTypeStmt &); // pop the type's scope
   bool Pre(const parser::AccessSpec &);
   bool Pre(const parser::AsynchronousStmt &);
   bool Pre(const parser::ContiguousStmt &);
@@ -6011,14 +6014,93 @@ bool DeclarationVisitor::Pre(const parser::Enumerator &enumerator) {
   return false;
 }
 
+void DeclarationVisitor::Post(const parser::EnumDef &) {
+  enumerationState_ = EnumeratorState{}; // reset after each ENUM
+}
+
+// F2023 R766 EnumerationTypeDef — scope is pushed in Post(EnumerationTypeStmt)
+// and popped in Post(EndEnumerationTypeStmt).
 bool DeclarationVisitor::Pre(const parser::EnumerationTypeDef &x) {
+  BeginAttrs();
+  // TODO: Remove this and set true when ENUMERATION TYPEs are implemented.
   Say(std::get<parser::Statement<parser::EnumerationTypeStmt>>(x.t).source,
       "F2023 ENUMERATION TYPEs are not yet implemented"_err_en_US);
   return false;
 }
 
-void DeclarationVisitor::Post(const parser::EnumDef &) {
-  enumerationState_ = EnumeratorState{};
+// F2023 R767 EnumerationTypeStmt — create the enumeration type symbol
+// in the enclosing scope and push a DerivedType scope for it.
+void DeclarationVisitor::Post(const parser::EnumerationTypeStmt &x) {
+  const auto &name{std::get<parser::Name>(x.t)};
+  Attrs attrs{EndAttrs()};
+  if (const auto &optAccessSpec{
+          std::get<std::optional<parser::AccessSpec>>(x.t)};
+      optAccessSpec) {
+    if (!NonDerivedTypeScope().IsModule()) { // F2023 C7114
+      Say(currStmtSource().value(),
+          "Access specifier on ENUMERATION TYPE may only appear in the specification part of a module"_err_en_US);
+    }
+  }
+  DerivedTypeDetails details;
+  details.set_isEnumerationType(true);
+  auto &symbol{MakeSymbol(name, attrs, std::move(details))};
+  symbol.ReplaceName(name.source);
+  PushScope(Scope::Kind::DerivedType, &symbol); // popped by the END statement
+  // Add the hidden INTEGER(4) __ordinal component. Each enumerator's value
+  // is a structure constant holding its 1-based position in __ordinal, so
+  // enumerators stay distinct through folding and compare by ordinal.
+  SourceName ordinalName{context().SaveTempName(std::string{"__ordinal"})};
+  Symbol &ordinalSym{MakeSymbol(currScope(), ordinalName, Attrs{})};
+  ordinalSym.set_details(ObjectEntityDetails{});
+  ordinalSym.SetType(
+      currScope().MakeNumericType(TypeCategory::Integer, KindExpr{4}));
+  ordinalSym.set(Symbol::Flag::CompilerCreated);
+  symbol.get<DerivedTypeDetails>().add_component(ordinalSym);
+}
+
+// F2023 R768 EnumerationEnumeratorStmt — create PARAMETER symbols for
+// each enumerator name in the enclosing scope with 1-based ordinal init.
+bool DeclarationVisitor::Pre(const parser::EnumerationEnumeratorStmt &x) {
+  Scope &enclosingScope{NonDerivedTypeScope()};
+  Symbol *typeSymbol{currScope().symbol()}; // the enumeration type
+  CHECK(typeSymbol);
+  auto &typeDetails{typeSymbol->get<DerivedTypeDetails>()};
+  // Hidden __ordinal component created by Post(EnumerationTypeStmt)
+  auto ordinalIter{currScope().find(SourceName{"__ordinal", 9})};
+  CHECK(ordinalIter != currScope().end());
+  const Symbol &ordinalSym{*ordinalIter->second};
+  DerivedTypeSpec enumTypeSpec{typeSymbol->name(), *typeSymbol};
+  enumTypeSpec.set_category(DerivedTypeSpec::Category::EnumerationType);
+  DeclTypeSpec &declType{enclosingScope.MakeDerivedType(
+      DeclTypeSpec::TypeDerived, std::move(enumTypeSpec))};
+  for (const parser::Name &name : x.v) {
+    if (Symbol * prior{FindInScope(enclosingScope, name)};
+        prior && !prior->has<UnknownDetails>()) {
+      SayAlreadyDeclared(name, *prior); // as in Pre(parser::Enumerator)
+      continue;
+    }
+    int ordinal{typeDetails.enumeratorCount() + 1};
+    // The enumerator lives in the enclosing scope, not the type's scope.
+    Symbol &enumerator{
+        MakeSymbol(enclosingScope, name.source, Attrs{Attr::PARAMETER})};
+    Resolve(name, enumerator);
+    enumerator.set_details(ObjectEntityDetails{});
+    enumerator.SetType(declType);
+    // Init: an enumeration constant whose __ordinal holds the 1-based ordinal
+    evaluate::StructureConstructor enumCtor{declType.derivedTypeSpec()};
+    enumCtor.Add(ordinalSym,
+        evaluate::AsGenericExpr(evaluate::Expr<evaluate::CInteger>{ordinal}));
+    enumerator.get<ObjectEntityDetails>().set_init(
+        SomeExpr{evaluate::Expr<evaluate::SomeDerived>{
+            evaluate::Constant<evaluate::SomeDerived>{std::move(enumCtor)}}});
+    typeDetails.set_enumeratorCount(ordinal);
+  }
+  return false;
+}
+
+// F2023 R769: pop the scope pushed by Post(EnumerationTypeStmt).
+void DeclarationVisitor::Post(const parser::EndEnumerationTypeStmt &) {
+  PopScope();
 }
 
 bool DeclarationVisitor::Pre(const parser::AccessSpec &x) {
@@ -6610,6 +6692,17 @@ void DeclarationVisitor::Post(const parser::DerivedTypeSpec &x) {
   // in the current scope, this spec will be moved into that collection.
   const auto &dtDetails{spec->typeSymbol().get<DerivedTypeDetails>()};
   auto category{GetDeclTypeSpecCategory()};
+
+  // Enumeration types skip the rest of this function: mark the spec's
+  // category, record its DeclTypeSpec, and return.
+  if (dtDetails.isEnumerationType()) {
+    spec->set_category(DerivedTypeSpec::Category::EnumerationType);
+    DeclTypeSpec &type{currScope().MakeDerivedType(category, std::move(*spec))};
+    SetDeclTypeSpec(type);
+    x.derivedTypeSpec = &GetDeclTypeSpec()->derivedTypeSpec();
+    return;
+  }
+
   if (dtDetails.isForwardReferenced()) {
     DeclTypeSpec &type{currScope().MakeDerivedType(category, std::move(*spec))};
     SetDeclTypeSpec(type);
@@ -8930,6 +9023,12 @@ class ExecutionPartSkimmerBase {
     return true;
   }
   void Post(const parser::DerivedTypeDef &) { PopScope(); }
+  bool Pre(const parser::EnumerationTypeStmt &x) {
+    Hide(std::get<parser::Name>(x.t));
+    PushScope(); // popped in Post(EnumerationTypeDef)
+    return true;
+  }
+  void Post(const parser::EnumerationTypeDef &) { PopScope(); }
   bool Pre(const parser::SelectTypeConstruct &) {
     PushScope();
     return true;
@@ -9404,6 +9503,12 @@ const parser::Name *DeclarationVisitor::FindComponent(
       return &component;
     }
   } else if (DerivedTypeSpec * derived{type->AsDerived()}) {
+    if (derived->IsEnumerationType()) { // only a hidden __ordinal
+      Say(component.source,
+          "Component reference is not allowed for enumeration type '%s'"_err_en_US,
+          derived->typeSymbol().name());
+      return nullptr;
+    }
     derived->Instantiate(currScope()); // in case of forward referenced type
     if (const Scope * scope{derived->scope()}) {
       if (Resolve(component, scope->FindComponent(component.source))) {
@@ -11099,6 +11204,25 @@ class DeferredCheckVisitor {
     }
   }
 
+  void Post(const parser::EnumerationTypeStmt &x) {
+    const auto &name{std::get<parser::Name>(x.t)};
+    if (Symbol * symbol{name.symbol}) {
+      if (Scope * scope{symbol->scope()}) {
+        if (scope->IsDerivedType()) {
+          CHECK(outerScope_ == nullptr);
+          outerScope_ = &resolver_.currScope();
+          resolver_.SetScope(*scope); // enter the type's scope
+        }
+      }
+    }
+  }
+  void Post(const parser::EndEnumerationTypeStmt &) {
+    if (outerScope_) {
+      resolver_.SetScope(*outerScope_); // back to the enclosing scope
+      outerScope_ = nullptr;
+    }
+  }
+
   void Post(const parser::ProcInterface &pi) {
     if (const auto *name{std::get_if<parser::Name>(&pi.u)}) {
       resolver_.CheckExplicitInterface(*name);
diff --git a/flang/lib/Semantics/rewrite-parse-tree.cpp b/flang/lib/Semantics/rewrite-parse-tree.cpp
index 4e1c9bae9c153..fd323cbb0177c 100644
--- a/flang/lib/Semantics/rewrite-parse-tree.cpp
+++ b/flang/lib/Semantics/rewrite-parse-tree.cpp
@@ -81,6 +81,7 @@ class RewriteMutator {
   bool Pre(parser::EndSubmoduleStmt &) { return false; }
   bool Pre(parser::EndSubroutineStmt &) { return false; }
   bool Pre(parser::EndTypeStmt &) { return false; }
+  bool Pre(parser::EndEnumerationTypeStmt &) { return false; } // cf. EndTypeStmt
 
   bool Pre(parser::OmpBlockConstruct &);
   bool Pre(parser::OpenMPLoopConstruct &);
diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp
index 79511c93b79b4..c965dc0d1c32d 100644
--- a/flang/lib/Semantics/tools.cpp
+++ b/flang/lib/Semantics/tools.cpp
@@ -182,9 +182,18 @@ bool IsIntrinsicRelational(common::RelationalOperator opr,
       return opr == common::RelationalOperator::EQ ||
           opr == common::RelationalOperator::NE ||
           (cat0 != TypeCategory::Complex && cat1 != TypeCategory::Complex);
+    } else if (cat0 == TypeCategory::Character &&
+        cat1 == TypeCategory::Character) {
+      return true;
+    } else if (cat0 == TypeCategory::Derived && cat1 == TypeCategory::Derived) {
+      // Two operands of the same enumeration type: all six operators apply
+      const auto *derived0{evaluate::GetDerivedTypeSpec(type0)};
+      const auto *derived1{evaluate::GetDerivedTypeSpec(type1)};
+      return derived0 && derived1 && derived0->IsEnumerationType() &&
+          derived1->IsEnumerationType() &&
+          &derived0->typeSymbol() == &derived1->typeSymbol();
     } else {
-      // not both numeric: only Character is ok
-      return cat0 == TypeCategory::Character && cat1 == TypeCategory::Character;
+      return false;
     }
   }
 }
diff --git a/flang/test/Semantics/case01.f90 b/flang/test/Semantics/case01.f90
index c9631d299e49c..7caa453ef6252 100644
--- a/flang/test/Semantics/case01.f90
+++ b/flang/test/Semantics/case01.f90
@@ -45,7 +45,7 @@ program selectCaseProg
    end select
 
   ! C1145
-  !ERROR: SELECT CASE expression must be integer, logical, or character
+  !ERROR: SELECT CASE expression must be integer, logical, character, or enumeration type
   select case (grade4)
      case (1.0)
      case (2.0)
@@ -53,7 +53,7 @@ program selectCaseProg
      case default
   end select
 
-  !ERROR: SELECT CASE expression must be integer, logical, or character
+  !ERROR: SELECT CASE expression must be integer, logical, character, or enumeration type
   select case (score)
      case (score_val)
      case (scores(100))

>From d704a6ef5f9001e226709e2b5535dfbe08e768d8 Mon Sep 17 00:00:00 2001
From: Kevin Wyatt <kwyatt at hpe.com>
Date: Mon, 20 Apr 2026 11:13:28 -0500
Subject: [PATCH 2/4] Add tests and the INT() folding required for
 relationals.

---
 flang/lib/Evaluate/fold-integer.cpp           |  28 +++++
 flang/lib/Semantics/resolve-names.cpp         |   5 +-
 .../enumeration-type-declarations.f90         |  84 +++++++++++++
 .../Semantics/enumeration-type-relational.f90 | 117 ++++++++++++++++++
 4 files changed, 230 insertions(+), 4 deletions(-)
 create mode 100644 flang/test/Semantics/enumeration-type-declarations.f90
 create mode 100644 flang/test/Semantics/enumeration-type-relational.f90

diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index 9f2bb94a9213f..d5dcf272d53d7 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -761,6 +761,34 @@ std::optional<Expr<T>> FoldIntrinsicFunctionCommon(
   } else if (name == "int" || name == "int2" || name == "int8" ||
       name == "uint") {
     if (auto *expr{UnwrapExpr<Expr<SomeType>>(args[0])}) {
+      // Check for enumeration type argument first — extract __ordinal
+      if (auto *derivedExpr{std::get_if<Expr<SomeDerived>>(&expr->u)}) {
+        if (auto type{derivedExpr->GetType()}) {
+          if (const auto *derived{GetDerivedTypeSpec(*type)}) {
+            if (derived->IsEnumerationType()) {
+              if (const auto *scope{derived->GetScope()}) {
+                auto ordIter{
+                    scope->find(semantics::SourceName{"__ordinal", 9})};
+                if (ordIter != scope->end()) {
+                  const semantics::Symbol &ordSym{*ordIter->second};
+                  if (auto *constant{
+                          UnwrapConstantValue<SomeDerived>(*derivedExpr)}) {
+                    if (auto sc{constant->GetScalarValue()}) {
+                      if (auto ordExpr{sc->Find(ordSym)}) {
+                        if (auto ordVal{ToInt64(*ordExpr)}) {
+                          return Expr<T>{Constant<T>{Scalar<T>{*ordVal}}};
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+              // Non-constant enumeration argument — leave unfolded
+              return Expr<T>{std::move(funcRef)};
+            }
+          }
+        }
+      }
       return common::visit(
           [&](auto &&x) -> Expr<T> {
             using From = std::decay_t<decltype(x)>;
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 2db60f7e84d9f..c1e57e9dedc32 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -6022,10 +6022,7 @@ void DeclarationVisitor::Post(const parser::EnumDef &) {
 // and popped in Post(EndEnumerationTypeStmt).
 bool DeclarationVisitor::Pre(const parser::EnumerationTypeDef &x) {
   BeginAttrs();
-  // TODO: Remove this and set true when ENUMERATION TYPEs are implemented.
-  Say(std::get<parser::Statement<parser::EnumerationTypeStmt>>(x.t).source,
-      "F2023 ENUMERATION TYPEs are not yet implemented"_err_en_US);
-  return false;
+  return true;
 }
 
 // F2023 R767 EnumerationTypeStmt — create the enumeration type symbol
diff --git a/flang/test/Semantics/enumeration-type-declarations.f90 b/flang/test/Semantics/enumeration-type-declarations.f90
new file mode 100644
index 0000000000000..de66ae888268e
--- /dev/null
+++ b/flang/test/Semantics/enumeration-type-declarations.f90
@@ -0,0 +1,84 @@
+! RUN: %python %S/test_errors.py %s %flang_fc1
+! Test declaration, constructor, and expression semantics for enumeration types
+
+! C7114: access specifier only allowed in module
+subroutine test_access_specifier_outside_module()
+  !ERROR: PRIVATE attribute may only appear in the specification part of a module
+  !ERROR: Access specifier on ENUMERATION TYPE may only appear in the specification part of a module
+  enumeration type, private :: color
+    enumerator :: red, green, blue
+  end enumeration type
+end subroutine
+
+! Valid: basic declarations and usage
+subroutine test_basic_declarations()
+  enumeration type :: color
+    enumerator :: red, green, blue
+  end enumeration type
+
+  type(color) :: c1, c2
+  logical :: l
+
+  ! Valid: assign an enumerator
+  c1 = red
+  c2 = blue
+
+  ! Valid: comparison produces logical
+  l = (c1 == c2)
+  l = (c1 /= red)
+end subroutine
+
+! Valid: constructor syntax — color(n) where n is a positive integer <= count
+subroutine test_constructor_valid()
+  enumeration type :: color
+    enumerator :: red, green, blue
+  end enumeration type
+
+  type(color) :: c
+
+  ! Valid: integer constructor in range
+  c = color(1)
+  c = color(2)
+  c = color(3)
+end subroutine
+
+! Constructor errors
+subroutine test_constructor_errors()
+  enumeration type :: color
+    enumerator :: red, green, blue
+  end enumeration type
+
+  type(color) :: c
+
+  ! ERROR: Enumeration constructor for 'color' requires exactly one argument
+  c = color()
+
+  ! ERROR: Enumeration constructor for 'color' requires exactly one argument
+  c = color(1, 2)
+
+  ! ERROR: Enumeration constructor for 'color' may not have a keyword argument
+  c = color(val=1)
+
+  ! ERROR: Enumeration constructor argument must be INTEGER, but is REAL(4)
+  c = color(1.0)
+
+  ! ERROR: Enumeration constructor value (0) for 'color' must be positive and less than or equal to the number of enumerators (3)
+  c = color(0)
+
+  ! ERROR: Enumeration constructor value (4) for 'color' must be positive and less than or equal to the number of enumerators (3)
+  c = color(4)
+end subroutine
+
+! Component reference on enumeration type is not allowed
+subroutine test_component_reference()
+  enumeration type :: color
+    enumerator :: red, green, blue
+  end enumeration type
+
+  type(color) :: c
+  integer :: i
+
+  c = red
+  ! ERROR: Component reference is not allowed for enumeration type 'color'
+  i = c%__ordinal
+end subroutine
diff --git a/flang/test/Semantics/enumeration-type-relational.f90 b/flang/test/Semantics/enumeration-type-relational.f90
new file mode 100644
index 0000000000000..507635c6bbdd1
--- /dev/null
+++ b/flang/test/Semantics/enumeration-type-relational.f90
@@ -0,0 +1,117 @@
+! RUN: %python %S/test_errors.py %s %flang_fc1
+! Test relational operators and SELECT CASE for enumeration types (F2023 7.6.2)
+
+module enum_mod
+  enumeration type :: color
+    enumerator :: red, green, blue
+  end enumeration type
+
+  enumeration type :: direction
+    enumerator :: north, south, east, west
+  end enumeration type
+
+  enumeration type :: w_value
+    enumerator :: w1, w2, w3, w4, w5
+  end enumeration type
+end module
+
+subroutine test_relational_same_type()
+  use enum_mod
+  logical :: result
+
+  ! Valid: all six relational operators between same-type enumerators
+  result = red == red
+  result = red /= green
+  result = red < green
+  result = green > red
+  result = red <= red
+  result = blue >= green
+end subroutine
+
+subroutine test_relational_cross_type()
+  use enum_mod
+
+  ! ERROR: Operands of .EQ. must have comparable types; have TYPE(color) and TYPE(direction)
+  if (red == north) stop 1
+
+  ! ERROR: Operands of .LT. must have comparable types; have TYPE(color) and TYPE(direction)
+  if (red < north) stop 2
+end subroutine
+
+subroutine test_relational_enum_vs_integer()
+  use enum_mod
+
+  ! ERROR: Operands of .EQ. must have comparable types; have TYPE(color) and INTEGER(4)
+  if (red == 1) stop 1
+
+  ! ERROR: Operands of .EQ. must have comparable types; have INTEGER(4) and TYPE(color)
+  if (1 == red) stop 2
+end subroutine
+
+subroutine test_select_case_basic(w)
+  use enum_mod
+  type(w_value), intent(in) :: w
+
+  ! Valid: SELECT CASE with enumerator names as case values
+  select case (w)
+    case (w1)
+      print *, 'w1'
+    case (w2)
+      print *, 'w2'
+    case default
+      print *, 'other'
+  end select
+end subroutine
+
+subroutine test_select_case_range(w)
+  use enum_mod
+  type(w_value), intent(in) :: w
+
+  ! Valid: SELECT CASE with ranges
+  select case (w)
+    case (w1)
+      print *, 'w1'
+    case (w2:w4)
+      print *, 'w2 to w4'
+    case (w5)
+      print *, 'w5'
+  end select
+end subroutine
+
+subroutine test_select_case_wrong_enum(w)
+  use enum_mod
+  type(w_value), intent(in) :: w
+
+  select case (w)
+    !ERROR: CASE value has type 'color' which is not compatible with the SELECT CASE expression's type 'ENUMERATION TYPE :: w_value'
+    case (red)
+      print *, 'wrong'
+    case default
+      print *, 'ok'
+  end select
+end subroutine
+
+subroutine test_select_case_integer_case(w)
+  use enum_mod
+  type(w_value), intent(in) :: w
+
+  select case (w)
+    !ERROR: CASE value has type 'INTEGER(4)' which is not compatible with the SELECT CASE expression's type 'ENUMERATION TYPE :: w_value'
+    case (1)
+      print *, 'wrong'
+    case default
+      print *, 'ok'
+  end select
+end subroutine
+
+subroutine test_select_case_non_enum_derived()
+  type :: my_type
+    integer :: val
+  end type
+  type(my_type) :: x = my_type(1)
+
+  !ERROR: SELECT CASE expression must be integer, logical, character, or enumeration type
+  select case (x)
+    case default
+  end select
+end subroutine

>From a0be48f5137e4dedce4270056796bdc7586bf91d Mon Sep 17 00:00:00 2001
From: Kevin Wyatt <kwyatt at hpe.com>
Date: Wed, 10 Jun 2026 14:54:05 -0500
Subject: [PATCH 3/4] Per review: require a scalar enumeration constructor
 argument and convert it to the __ordinal component's declared kind.

---
 flang/lib/Semantics/expression.cpp                | 15 +++++++++++++--
 .../Semantics/enumeration-type-declarations.f90   |  3 +++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 26bd9fa3b382e..ba630b8d29c9e 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2507,6 +2507,13 @@ MaybeExpr ExpressionAnalyzer::AnalyzeEnumerationConstructor(
         argType ? argType->AsFortran() : std::string{"typeless"});
     return std::nullopt;
   }
+  // F2023 R771: the argument shall be a scalar-int-expr.
+  if (folded.Rank() > 0) {
+    Say(argExpr.source,
+        "Enumeration constructor argument for '%s' must be scalar"_err_en_US,
+        typeName);
+    return std::nullopt;
+  }
   // If the value is known at compile time, validate the range
   if (auto value{ToInt64(folded)}) {
     if (*value < 1 || *value > enumeratorCount) {
@@ -2516,12 +2523,16 @@ MaybeExpr ExpressionAnalyzer::AnalyzeEnumerationConstructor(
       return std::nullopt;
     }
   }
-  // Produce an Expr<SomeDerived> with the ordinal in the __ordinal component
+  // Produce an Expr<SomeDerived> with the ordinal in the __ordinal component,
+  // converted to the component's declared integer kind.
   StructureConstructor result{spec};
   if (const auto *scope{spec.GetScope()}) {
     auto ordinalIter{scope->find(semantics::SourceName{"__ordinal", 9})};
     if (ordinalIter != scope->end()) {
-      result.Add(*ordinalIter->second, std::move(folded));
+      const Symbol &ordinalSymbol{*ordinalIter->second};
+      if (auto converted{ConvertToType(ordinalSymbol, std::move(folded))}) {
+        result.Add(ordinalSymbol, std::move(*converted));
+      }
     }
   }
   return AsMaybeExpr(Expr<SomeDerived>{std::move(result)});
diff --git a/flang/test/Semantics/enumeration-type-declarations.f90 b/flang/test/Semantics/enumeration-type-declarations.f90
index de66ae888268e..a03d8dcca78c0 100644
--- a/flang/test/Semantics/enumeration-type-declarations.f90
+++ b/flang/test/Semantics/enumeration-type-declarations.f90
@@ -62,6 +62,9 @@ subroutine test_constructor_errors()
   ! ERROR: Enumeration constructor argument must be INTEGER, but is REAL(4)
   c = color(1.0)
 
+  ! ERROR: Enumeration constructor argument for 'color' must be scalar
+  c = color([1, 2])
+
   ! ERROR: Enumeration constructor value (0) for 'color' must be positive and less than or equal to the number of enumerators (3)
   c = color(0)
 

>From caa39263c6faa1a0cb5a5435d8deded9d3f16a0f Mon Sep 17 00:00:00 2001
From: Kevin Wyatt <kwyatt at hpe.com>
Date: Tue, 23 Jun 2026 16:57:27 -0500
Subject: [PATCH 4/4] Use the ultimate symbol in the ExpressionAnalyzer::Analyze
 derived-type check and in IsEnumerationType, and add test cases for
 USE-associated enumeration types.

---
 flang/lib/Semantics/expression.cpp            |  8 ++--
 flang/lib/Semantics/tools.cpp                 |  4 +-
 .../enumeration-type-declarations.f90         | 37 +++++++++++++++++++
 3 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index ba630b8d29c9e..60b7a6ccbd605 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -3650,9 +3650,11 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::FunctionReference &funcRef,
         if (!CheckIsValidForwardReference(dtSpec)) {
           return std::nullopt;
         }
-        // Detect enumeration types and set the category accordingly
-        if (const auto *dtDetails{
-                symbol.detailsIf<semantics::DerivedTypeDetails>()}) {
+        // Detect enumeration types and set the category accordingly.
+        // Use the ultimate symbol so that a USE-associated enumeration type
+        // (whose local symbol carries UseDetails) is recognized too.
+        if (const auto *dtDetails{symbol.GetUltimate()
+                    .detailsIf<semantics::DerivedTypeDetails>()}) {
           if (dtDetails->isEnumerationType()) {
             dtSpec.set_category(
                 semantics::DerivedTypeSpec::Category::EnumerationType);
diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp
index c965dc0d1c32d..92bd599fdab2f 100644
--- a/flang/lib/Semantics/tools.cpp
+++ b/flang/lib/Semantics/tools.cpp
@@ -1086,7 +1086,9 @@ bool IsAssumedType(const Symbol &symbol) {
 }
 
 bool IsEnumerationType(const Symbol &symbol) {
-  if (const auto *details{symbol.detailsIf<DerivedTypeDetails>()}) {
+  // Use the ultimate symbol for cases such as USE-associated enumeration types
+  if (const auto *details{
+          symbol.GetUltimate().detailsIf<DerivedTypeDetails>()}) {
     return details->isEnumerationType();
   }
   return false;
diff --git a/flang/test/Semantics/enumeration-type-declarations.f90 b/flang/test/Semantics/enumeration-type-declarations.f90
index a03d8dcca78c0..72375dc3c6836 100644
--- a/flang/test/Semantics/enumeration-type-declarations.f90
+++ b/flang/test/Semantics/enumeration-type-declarations.f90
@@ -85,3 +85,40 @@ subroutine test_component_reference()
   ! ERROR: Component reference is not allowed for enumeration type 'color'
   i = c%__ordinal
 end subroutine
+
+! Module providing an enumeration type by USE association
+module enum_constructor_mod
+  enumeration type :: color
+    enumerator :: red, green, blue
+  end enumeration type
+end module
+
+! Constructor errors for a USE-associated enumeration type.
+! This exercises the cross-module path: the type's local symbol carries
+! UseDetails, so the enumeration-specific checks must follow USE association.
+subroutine test_constructor_errors_use()
+  use enum_constructor_mod
+
+  type(color) :: c
+
+  ! ERROR: Enumeration constructor for 'color' requires exactly one argument
+  c = color()
+
+  ! ERROR: Enumeration constructor for 'color' requires exactly one argument
+  c = color(1, 2)
+
+  ! ERROR: Enumeration constructor for 'color' may not have a keyword argument
+  c = color(val=1)
+
+  ! ERROR: Enumeration constructor argument must be INTEGER, but is REAL(4)
+  c = color(1.0)
+
+  ! ERROR: Enumeration constructor argument for 'color' must be scalar
+  c = color([1, 2])
+
+  ! ERROR: Enumeration constructor value (0) for 'color' must be positive and less than or equal to the number of enumerators (3)
+  c = color(0)
+
+  ! ERROR: Enumeration constructor value (4) for 'color' must be positive and less than or equal to the number of enumerators (3)
+  c = color(4)
+end subroutine



More information about the flang-commits mailing list