[flang-commits] [flang] [flang] Implement conditional expressions parser/semantics (F2023) (PR #186489)
Caroline Newcombe via flang-commits
flang-commits at lists.llvm.org
Fri Mar 20 13:06:52 PDT 2026
https://github.com/cenewcombe updated https://github.com/llvm/llvm-project/pull/186489
>From ad5bebc18ca8883e577560a98029921a482a76c1 Mon Sep 17 00:00:00 2001
From: Caroline Newcombe <caroline.newcombe at hpe.com>
Date: Fri, 13 Mar 2026 11:19:56 -0500
Subject: [PATCH 1/2] [flang] Implement conditional expressions
parser/semantics (F2023)
Implements Fortran 2023 conditional expressions (R1002):
result = (condition ? value1 : condition2 ? value2 : ... : elseValue)
This adds:
- Parser support for conditional expression syntax using ? and :
- Semantic analysis with type checking (all values must have matching
type, kind, and rank; conditions must be scalar logical)
- ConditionalExpr node in the expression tree with N conditions and
N+1 values (last value is the else branch)
- Placeholder lowering hooks only (TODO/fatal-error stubs); the actual lowering to HLFIR is implemented in a separate branch/PR
- LIT test coverage for semantics and parsing
Current limitations:
- Conditional expressions as actual arguments are not yet implemented
- Polymorphic (CLASS) types not yet supported
This implements llvm#176999
---
flang/examples/FeatureList/FeatureList.cpp | 1 +
flang/include/flang/Evaluate/expression.h | 49 ++-
flang/include/flang/Evaluate/shape.h | 10 +
flang/include/flang/Evaluate/tools.h | 28 ++
flang/include/flang/Evaluate/traverse.h | 4 +
flang/include/flang/Parser/characters.h | 1 +
flang/include/flang/Parser/dump-parse-tree.h | 2 +
flang/include/flang/Parser/parse-tree.h | 24 +-
flang/include/flang/Semantics/dump-expr.h | 22 ++
flang/include/flang/Semantics/expression.h | 1 +
flang/lib/Evaluate/check-expression.cpp | 139 +++++++
flang/lib/Evaluate/expression.cpp | 23 ++
flang/lib/Evaluate/formatting.cpp | 14 +
flang/lib/Evaluate/shape.cpp | 63 +++-
flang/lib/Evaluate/tools.cpp | 121 ++++++
flang/lib/Lower/ConvertExpr.cpp | 10 +
flang/lib/Lower/ConvertExprToHLFIR.cpp | 6 +
flang/lib/Lower/IterationSpace.cpp | 11 +
flang/lib/Lower/Support/Utils.cpp | 32 ++
flang/lib/Parser/basic-parsers.h | 32 ++
flang/lib/Parser/expr-parsers.cpp | 12 +
flang/lib/Parser/unparse.cpp | 12 +
flang/lib/Semantics/check-cuda.cpp | 30 ++
flang/lib/Semantics/check-data.cpp | 15 +
flang/lib/Semantics/check-do-forall.cpp | 11 +
flang/lib/Semantics/definable.cpp | 14 +
flang/lib/Semantics/expression.cpp | 216 +++++++++++
flang/lib/Semantics/openmp-utils.cpp | 25 ++
flang/lib/Semantics/resolve-names-utils.cpp | 14 +
flang/test/Parser/conditional-expr.f90 | 261 +++++++++++++
flang/test/Semantics/conditional-expr.f90 | 365 +++++++++++++++++++
31 files changed, 1548 insertions(+), 20 deletions(-)
create mode 100644 flang/test/Parser/conditional-expr.f90
create mode 100644 flang/test/Semantics/conditional-expr.f90
diff --git a/flang/examples/FeatureList/FeatureList.cpp b/flang/examples/FeatureList/FeatureList.cpp
index 355d79a04e4ba..bee18096f9fb2 100644
--- a/flang/examples/FeatureList/FeatureList.cpp
+++ b/flang/examples/FeatureList/FeatureList.cpp
@@ -311,6 +311,9 @@ struct NodeVisitor {
READ_FEATURE(Expr::NEQV)
READ_FEATURE(Expr::DefinedBinary)
READ_FEATURE(Expr::ComplexConstructor)
+  // F2023 conditional expression (R1002) and its nested "cond ? value" node.
+  READ_FEATURE(ConditionalExpr)
+  READ_FEATURE(ConditionalExpr::Branch)
READ_FEATURE(External)
READ_FEATURE(ExternalStmt)
READ_FEATURE(FailImageStmt)
diff --git a/flang/include/flang/Evaluate/expression.h b/flang/include/flang/Evaluate/expression.h
index f7a1f9b955181..d46699cb7ac2c 100644
--- a/flang/include/flang/Evaluate/expression.h
+++ b/flang/include/flang/Evaluate/expression.h
@@ -390,6 +390,36 @@ struct LogicalOperation
LogicalOperator logicalOperator;
};
+// Fortran 2023 conditional expression: (cond ? val : cond ? val : ... : else)
+// All branches have the same type and rank (verified during semantic analysis).
+// Invariant (CHECKed by the constructor): values().size() ==
+// conditions().size() + 1.  values()[i] is paired with conditions()[i] (see
+// AsFortran), and values().back() is the "else" value.
+// NOTE(review): the non-const accessors let callers break this invariant.
+template <typename T> class ConditionalExpr {
+public:
+ using Result = T;
+ CLASS_BOILERPLATE(ConditionalExpr)
+ ConditionalExpr(
+ std::vector<Expr<SomeLogical>> &&conds, std::vector<Expr<Result>> &&vals)
+ : conditions_{std::move(conds)}, values_{std::move(vals)} {
+ CHECK(values_.size() == conditions_.size() + 1);
+ }
+ bool operator==(const ConditionalExpr &) const;
+ // The N scalar LOGICAL conditions, in the order they appear.
+ const std::vector<Expr<SomeLogical>> &conditions() const {
+ return conditions_;
+ }
+ std::vector<Expr<SomeLogical>> &conditions() { return conditions_; }
+ // The N+1 selectable values; the last one is the "else" value.
+ const std::vector<Expr<Result>> &values() const { return values_; }
+ std::vector<Expr<Result>> &values() { return values_; }
+ // Rank and type are taken from the first value; semantic analysis has
+ // already required all values to agree.
+ int Rank() const { return values_.empty() ? 0 : values_.front().Rank(); }
+ std::optional<DynamicType> GetType() const {
+ return values_.empty() ? std::nullopt : values_.front().GetType();
+ }
+ static constexpr int Corank() { return 0; }
+ llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
+
+private:
+ std::vector<Expr<SomeLogical>> conditions_; // size N
+ std::vector<Expr<Result>> values_; // size N+1 (includes else)
+};
+
// Array constructors
template <typename RESULT> class ArrayConstructorValues;
@@ -536,7 +566,7 @@ class Expr<Type<TypeCategory::Integer, KIND>>
Convert<Result, TypeCategory::Unsigned>>;
using Operations = std::tuple<Parentheses<Result>, Negate<Result>,
Add<Result>, Subtract<Result>, Multiply<Result>, Divide<Result>,
- Power<Result>, Extremum<Result>>;
+ Power<Result>, Extremum<Result>, ConditionalExpr<Result>>;
using Indices = std::conditional_t<KIND == ImpliedDoIndex::Result::kind,
std::tuple<ImpliedDoIndex>, std::tuple<>>;
using TypeParamInquiries =
@@ -568,7 +598,7 @@ class Expr<Type<TypeCategory::Unsigned, KIND>>
Convert<Result, TypeCategory::Unsigned>>;
using Operations = std::tuple<Parentheses<Result>, Negate<Result>,
Add<Result>, Subtract<Result>, Multiply<Result>, Divide<Result>,
- Power<Result>, Extremum<Result>>;
+ Power<Result>, Extremum<Result>, ConditionalExpr<Result>>;
using Others = std::tuple<Constant<Result>, ArrayConstructor<Result>,
Designator<Result>, FunctionRef<Result>>;
@@ -594,7 +624,8 @@ class Expr<Type<TypeCategory::Real, KIND>>
Convert<Result, TypeCategory::Unsigned>>;
using Operations = std::variant<ComplexComponent<KIND>, Parentheses<Result>,
Negate<Result>, Add<Result>, Subtract<Result>, Multiply<Result>,
- Divide<Result>, Power<Result>, RealToIntPower<Result>, Extremum<Result>>;
+ Divide<Result>, Power<Result>, RealToIntPower<Result>, Extremum<Result>,
+ ConditionalExpr<Result>>;
using Others = std::variant<Constant<Result>, ArrayConstructor<Result>,
Designator<Result>, FunctionRef<Result>>;
@@ -612,7 +643,7 @@ class Expr<Type<TypeCategory::Complex, KIND>>
using Operations = std::variant<Parentheses<Result>, Negate<Result>,
Convert<Result, TypeCategory::Complex>, Add<Result>, Subtract<Result>,
Multiply<Result>, Divide<Result>, Power<Result>, RealToIntPower<Result>,
- ComplexConstructor<KIND>>;
+ ComplexConstructor<KIND>, ConditionalExpr<Result>>;
using Others = std::variant<Constant<Result>, ArrayConstructor<Result>,
Designator<Result>, FunctionRef<Result>>;
@@ -638,7 +669,7 @@ class Expr<Type<TypeCategory::Character, KIND>>
std::variant<Constant<Result>, ArrayConstructor<Result>, Designator<Result>,
FunctionRef<Result>, Parentheses<Result>, Convert<Result>, Concat<KIND>,
- Extremum<Result>, SetLength<KIND>>
+ Extremum<Result>, SetLength<KIND>, ConditionalExpr<Result>>
u;
};
@@ -710,7 +741,7 @@ class Expr<Type<TypeCategory::Logical, KIND>>
private:
using Operations = std::tuple<Convert<Result>, Parentheses<Result>, Not<KIND>,
- LogicalOperation<KIND>>;
+ LogicalOperation<KIND>, ConditionalExpr<Result>>;
using Relations = std::conditional_t<KIND == LogicalResult::kind,
std::tuple<Relational<SomeType>>, std::tuple<>>;
using Others = std::tuple<Constant<Result>, ArrayConstructor<Result>,
@@ -788,7 +819,8 @@ template <> class Expr<SomeDerived> : public ExpressionBase<SomeDerived> {
using Result = SomeDerived;
EVALUATE_UNION_CLASS_BOILERPLATE(Expr)
std::variant<Constant<Result>, ArrayConstructor<Result>, StructureConstructor,
- Designator<Result>, FunctionRef<Result>, Parentheses<Result>>
+ Designator<Result>, FunctionRef<Result>, Parentheses<Result>,
+ ConditionalExpr<Result>>
u;
};
@@ -929,6 +961,7 @@ FOR_EACH_INTRINSIC_KIND(extern template class ArrayConstructor, )
template class Relational<SomeType>; \
FOR_EACH_TYPE_AND_KIND(template class ExpressionBase, ) \
FOR_EACH_INTRINSIC_KIND(template class ArrayConstructorValues, ) \
- FOR_EACH_INTRINSIC_KIND(template class ArrayConstructor, )
+ FOR_EACH_INTRINSIC_KIND(template class ArrayConstructor, ) \
+ FOR_EACH_INTRINSIC_KIND(template class ConditionalExpr, )
} // namespace Fortran::evaluate
#endif // FORTRAN_EVALUATE_EXPRESSION_H_
diff --git a/flang/include/flang/Evaluate/shape.h b/flang/include/flang/Evaluate/shape.h
index f0505cfcdf2d7..3af78820f6c66 100644
--- a/flang/include/flang/Evaluate/shape.h
+++ b/flang/include/flang/Evaluate/shape.h
@@ -189,6 +189,29 @@ class GetShapeHelper
Result operator()(const ArrayConstructor<T> &aconst) const {
return Shape{GetArrayConstructorExtent(aconst)};
}
+  template <typename T>
+  Result operator()(const ConditionalExpr<T> &conditional) const {
+    // Per F2023 10.1.4(7), the shape is that of the selected value.  When
+    // every value has the same known shape, that shape is the result;
+    // otherwise only the rank is known, so the extents are left unknown.
+    const auto &values{conditional.values()};
+    if (values.empty()) {
+      return ScalarShape();
+    }
+    auto unknownExtents{[&values]() -> Result {
+      return Shape(values.front().Rank(), std::nullopt);
+    }};
+    Result first{(*this)(values.front())};
+    if (!first) {
+      return unknownExtents();
+    }
+    for (std::size_t j{1}; j < values.size(); ++j) {
+      if ((*this)(values[j]) != first) {
+        return unknownExtents();
+      }
+    }
+    return first;
+  }
template <typename D, typename R, typename LO, typename RO>
Result operator()(const Operation<D, R, LO, RO> &operation) const {
if (int rr{operation.right().Rank()}; rr > 0) {
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 0fded08456bcf..9ba29a5a2879c 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -50,6 +50,10 @@ struct IsVariableHelper
Result operator()(const CoarrayRef &) const { return true; }
Result operator()(const ComplexPart &) const { return true; }
Result operator()(const ProcedureDesignator &) const;
+  // A conditional expression (F2023) is a value, never a variable.
+  template <typename T> Result operator()(const ConditionalExpr<T> &) const {
+    return false;
+  }
template <typename T> Result operator()(const Expr<T> &x) const {
if constexpr (common::HasMember<T, AllIntrinsicTypes> ||
std::is_same_v<T, SomeDerived>) {
@@ -1071,6 +1075,17 @@ struct GetSymbolVectorHelper
Result operator()(const Component &) const;
Result operator()(const ArrayRef &) const;
Result operator()(const CoarrayRef &) const;
+  // Symbols from every condition, followed by those from every value.
+  template <typename T> Result operator()(const ConditionalExpr<T> &x) {
+    Result result;
+    for (const auto &cond : x.conditions()) {
+      result = Combine(std::move(result), (*this)(cond));
+    }
+    for (const auto &val : x.values()) {
+      result = Combine(std::move(result), (*this)(val));
+    }
+    return result;
+  }
};
template <typename A> SymbolVector GetSymbolVector(const A &x) {
return GetSymbolVectorHelper{}(x);
@@ -1159,6 +1174,20 @@ class UnsafeToCopyVisitor : public AnyTraverse<UnsafeToCopyVisitor> {
return !admitPureCall_ || !procRef.proc().IsPure();
}
bool operator()(const CoarrayRef &) { return true; }
+  template <typename T> bool operator()(const ConditionalExpr<T> &x) {
+    // A conditional expression is unsafe to copy if any of its parts are unsafe
+    for (const auto &condition : x.conditions()) {
+      if ((*this)(condition)) {
+        return true;
+      }
+    }
+    for (const auto &value : x.values()) {
+      if ((*this)(value)) {
+        return true;
+      }
+    }
+    return false;
+  }
private:
bool admitPureCall_{false};
@@ -1381,6 +1410,7 @@ enum class Operator {
Call,
+  Conditional,
Constant,
Convert,
Div,
Eq,
Eqv,
diff --git a/flang/include/flang/Evaluate/traverse.h b/flang/include/flang/Evaluate/traverse.h
index d63c16f93230a..306337274bf1f 100644
--- a/flang/include/flang/Evaluate/traverse.h
+++ b/flang/include/flang/Evaluate/traverse.h
@@ -224,6 +224,10 @@ class Traverse {
Result operator()(const StructureConstructor &x) const {
return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x));
}
+ // Conditional expressions (Fortran 2023): visit all conditions, then all
+ // values (the last value is the "else" value), combining the results.
+ // Visitors derived from Traverse therefore reach every operand without
+ // needing their own ConditionalExpr overload.
+ template <typename T> Result operator()(const ConditionalExpr<T> &x) const {
+ return Combine(x.conditions(), x.values());
+ }
// Operations and wrappers
// Have a single operator() for all Operations.
diff --git a/flang/include/flang/Parser/characters.h b/flang/include/flang/Parser/characters.h
index 3761700ad348c..620c6b357f948 100644
--- a/flang/include/flang/Parser/characters.h
+++ b/flang/include/flang/Parser/characters.h
@@ -170,6 +170,7 @@ inline constexpr bool IsValidFortranTokenCharacter(char ch) {
case '<':
case '=':
case '>':
+ case '?': // F2023 conditional expressions: ( cond ? value : ... : value )
case '[':
case ']':
case '{': // Used in OpenMP context selector specification
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 84c7b8d2a5349..4f6142af7eaf6 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -252,6 +252,8 @@ class ParseTreeDumper {
NODE(parser, ComputedGotoStmt)
NODE(parser, ConcurrentControl)
NODE(parser, ConcurrentHeader)
+ NODE(parser, ConditionalExpr)
+ NODE(ConditionalExpr, Branch)
NODE(parser, ConnectSpec)
NODE(ConnectSpec, CharExpr)
NODE_ENUM(ConnectSpec::CharExpr, Kind)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 4aec99c80bdae..706d31dca3331 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -1678,6 +1678,19 @@ struct ImageSelector {
std::tuple<std::list<Cosubscript>, std::list<ImageSelectorSpec>> t;
};
+// F2023: R1002 conditional-expr ->
+// ( scalar-logical-expr ? expr
+// [ : scalar-logical-expr ? expr ]...
+// : expr )
+// Each Branch holds one "scalar-logical-expr ? expr" pair (the grammar
+// requires at least one); the trailing Indirection<Expr> is the final expr.
+// common::Indirection is needed because Expr is still incomplete here.
+struct ConditionalExpr {
+ TUPLE_CLASS_BOILERPLATE(ConditionalExpr);
+ struct Branch {
+ TUPLE_CLASS_BOILERPLATE(Branch);
+ std::tuple<ScalarLogicalExpr, common::Indirection<Expr>> t;
+ };
+ std::tuple<std::list<Branch>, common::Indirection<Expr>> t;
+};
+
// R1001 - R1022 expressions
struct Expr {
UNION_CLASS_BOILERPLATE(Expr);
@@ -1776,11 +1789,12 @@ struct Expr {
CharBlock source;
std::variant<common::Indirection<CharLiteralConstantSubstring>,
- LiteralConstant, common::Indirection<Designator>, ArrayConstructor,
- StructureConstructor, common::Indirection<FunctionReference>, Parentheses,
- UnaryPlus, Negate, NOT, PercentLoc, DefinedUnary, Power, Multiply, Divide,
- Add, Subtract, Concat, LT, LE, EQ, NE, GE, GT, AND, OR, EQV, NEQV,
- DefinedBinary, ComplexConstructor, common::Indirection<SubstringInquiry>>
+ LiteralConstant, ConditionalExpr, common::Indirection<Designator>,
+ ArrayConstructor, StructureConstructor,
+ common::Indirection<FunctionReference>, Parentheses, UnaryPlus, Negate,
+ NOT, PercentLoc, DefinedUnary, Power, Multiply, Divide, Add, Subtract,
+ Concat, LT, LE, EQ, NE, GE, GT, AND, OR, EQV, NEQV, DefinedBinary,
+ ComplexConstructor, common::Indirection<SubstringInquiry>>
u;
};
diff --git a/flang/include/flang/Semantics/dump-expr.h b/flang/include/flang/Semantics/dump-expr.h
index 8cbb78b585f4a..5fbed77139958 100644
--- a/flang/include/flang/Semantics/dump-expr.h
+++ b/flang/include/flang/Semantics/dump-expr.h
@@ -201,6 +201,29 @@ class DumpEvaluateExpr {
Show(op.right());
Outdent();
}
+  template <typename T> void Show(const evaluate::ConditionalExpr<T> &x) {
+    Indent("conditional expr "s + std::string(TypeOf<T>::name));
+    const auto &conditions{x.conditions()};
+    const auto &values{x.values()};
+    // One "branch" node per condition, holding the value it selects.
+    for (std::size_t j{0}; j < conditions.size() && j < values.size(); ++j) {
+      Indent("branch");
+      Indent("condition");
+      Show(conditions[j]);
+      Outdent();
+      Indent("value");
+      Show(values[j]);
+      Outdent();
+      Outdent();
+    }
+    // The trailing value is the one used when no condition is true.
+    if (!values.empty()) {
+      Indent("default value");
+      Show(values.back());
+      Outdent();
+    }
+    Outdent();
+  }
void Show(const evaluate::Relational<evaluate::SomeType> &x);
template <typename T> void Show(const evaluate::Expr<T> &x) {
Indent("expr <"s + std::string(TypeOf<T>::name) + ">"s);
diff --git a/flang/include/flang/Semantics/expression.h b/flang/include/flang/Semantics/expression.h
index 490399aa03ff8..0054a86486e79 100644
--- a/flang/include/flang/Semantics/expression.h
+++ b/flang/include/flang/Semantics/expression.h
@@ -169,6 +169,7 @@ class ExpressionAnalyzer {
MaybeExpr Analyze(const parser::DataStmtValue &);
MaybeExpr Analyze(const parser::AllocateObject &);
MaybeExpr Analyze(const parser::PointerObject &);
+ // Analyzes and type-checks an F2023 conditional expression (R1002).
+ MaybeExpr Analyze(const parser::ConditionalExpr &);
template <typename A> MaybeExpr Analyze(const common::Indirection<A> &x) {
return Analyze(x.value());
diff --git a/flang/lib/Evaluate/check-expression.cpp b/flang/lib/Evaluate/check-expression.cpp
index e73a4d82951af..7c5e7c129765c 100644
--- a/flang/lib/Evaluate/check-expression.cpp
+++ b/flang/lib/Evaluate/check-expression.cpp
@@ -92,6 +92,21 @@ class IsConstantExprHelper
!sym.attrs().test(semantics::Attr::VALUE)));
}
+  template <typename T>
+  bool operator()(const ConditionalExpr<T> &conditional) const {
+    // Constant only when every condition and every value is constant.
+    auto allConstant{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if (!(*this)(operand)) {
+          return false;
+        }
+      }
+      return true;
+    }};
+    return allConstant(conditional.conditions()) &&
+        allConstant(conditional.values());
+  }
+
bool operator()(const ImpliedDoIndex &ido) const {
return acImpliedDos_.find(ido.name) != acImpliedDos_.end() || !context_ ||
context_->GetImpliedDo(ido.name).has_value();
@@ -229,6 +244,19 @@ struct IsActuallyConstantHelper {
template <typename T> bool operator()(const Parentheses<T> &x) {
return (*this)(x.left());
}
+  template <typename T> bool operator()(const ConditionalExpr<T> &x) {
+    // Actually constant only when every condition and every value is.
+    auto allActuallyConstant{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if (!(*this)(operand)) {
+          return false;
+        }
+      }
+      return true;
+    }};
+    return allActuallyConstant(x.conditions()) &&
+        allActuallyConstant(x.values());
+  }
template <typename T> bool operator()(const Expr<T> &x) {
return common::visit([=](const auto &y) { return (*this)(y); }, x.u);
}
@@ -358,6 +386,9 @@ class IsInitialDataTargetHelper
bool operator()(const Operation<D, R, O...> &) const {
return false;
}
+  template <typename T> bool operator()(const ConditionalExpr<T> &) const {
+    return false; // a conditional expression cannot be an initial data target
+  }
template <typename T> bool operator()(const Parentheses<T> &x) const {
return (*this)(x.left());
}
@@ -492,6 +527,18 @@ class SuspiciousRealLiteralFinder
}
return (*this)(x.left());
}
+  template <typename T> bool operator()(const ConditionalExpr<T> &x) const {
+    // Report a suspicious literal found in any condition or any value.
+    auto anySuspicious{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if ((*this)(operand)) {
+          return true;
+        }
+      }
+      return false;
+    }};
+    return anySuspicious(x.conditions()) || anySuspicious(x.values());
+  }
private:
int kind_;
@@ -531,6 +578,18 @@ class InexactLiteralConversionFlagClearer
mut.set_isFromInexactLiteralConversion(false);
return false;
}
+  template <typename T> bool operator()(const ConditionalExpr<T> &x) const {
+    // Clear the flag throughout every condition and every value; like the
+    // neighbouring overload, this always returns false.
+    auto clearAll{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        (*this)(operand);
+      }
+    }};
+    clearAll(x.conditions());
+    clearAll(x.values());
+    return false;
+  }
};
// Converts, folds, and then checks type, rank, and shape of an
@@ -798,6 +858,22 @@ class CheckSpecificationExprHelper
return std::nullopt;
}
+  template <typename T> Result operator()(const ConditionalExpr<T> &x) const {
+    // The first problem found in a condition (else in a value) is reported.
+    auto firstProblem{[this](const auto &operands) -> Result {
+      for (const auto &operand : operands) {
+        if (auto problem{(*this)(operand)}) {
+          return problem;
+        }
+      }
+      return std::nullopt;
+    }};
+    if (auto problem{firstProblem(x.conditions())}) {
+      return problem;
+    }
+    return firstProblem(x.values());
+  }
+
Result operator()(const ProcedureRef &x) const {
if (const auto *symbol{x.proc().GetSymbol()}) {
const Symbol &ultimate{symbol->GetUltimate()};
@@ -1193,6 +1267,31 @@ class IsContiguousHelper
Result operator()(const NullPointer &) const { return true; }
+ // Any value may be the one selected at run time, so the answer is definite
+ // only when all values agree. Only the values are examined (not the
+ // conditions), and every value is visited even once the outcome is settled.
+ template <typename T> Result operator()(const ConditionalExpr<T> &x) {
+ // Track contiguity across all possible runtime branches
+ bool hasContiguous{false};
+ bool hasNonContiguous{false};
+ bool hasUnknown{false};
+ for (const auto &val : x.values()) {
+ auto result{(*this)(val)};
+ if (!result) {
+ hasUnknown = true;
+ } else if (*result) {
+ hasContiguous = true;
+ } else {
+ hasNonContiguous = true;
+ }
+ }
+ // Return definite result only if all values have uniform contiguity
+ // (with no values at all -- ruled out by the ConditionalExpr constructor's
+ // CHECK -- control falls through to "false").
+ if (hasUnknown || (hasContiguous && hasNonContiguous)) {
+ return std::nullopt;
+ } else if (hasContiguous) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
private:
// Returns "true" for a provably empty or simply contiguous array section;
// return "false" for a provably nonempty discontiguous section or for use
@@ -1386,6 +1485,20 @@ struct IsErrorExprHelper : public AnyTraverse<IsErrorExprHelper, bool> {
bool operator()(const SpecificIntrinsic &x) {
return x.name == IntrinsicProcTable::InvalidName;
}
+
+  template <typename T> bool operator()(const ConditionalExpr<T> &x) {
+    // An error in any condition or any value makes the whole expression an
+    // error.
+    auto anyError{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if ((*this)(operand)) {
+          return true;
+        }
+      }
+      return false;
+    }};
+    return anyError(x.conditions()) || anyError(x.values());
+  }
};
template <typename A> bool IsErrorExpr(const A &x) {
@@ -1501,6 +1614,22 @@ class StmtFunctionChecker
return std::nullopt;
}
+  template <typename T> Result operator()(const ConditionalExpr<T> &x) {
+    // Report the first non-empty result from a condition, else from a value.
+    auto firstResult{[this](const auto &operands) -> Result {
+      for (const auto &operand : operands) {
+        if (auto result{(*this)(operand)}) {
+          return result;
+        }
+      }
+      return std::nullopt;
+    }};
+    if (auto result{firstResult(x.conditions())}) {
+      return result;
+    }
+    return firstResult(x.values());
+  }
+
private:
const Symbol &sf_;
FoldingContext &context_;
@@ -1760,6 +1887,20 @@ class CollectUsedSymbolValuesHelper
return {}; // doesn't count as a use
}
+  template <typename T> Result operator()(const ConditionalExpr<T> &condExpr) {
+    // Conditions and values are visited as uses (isDefinition_ cleared).
+    auto restorer{common::ScopedSet(isDefinition_, false)};
+    Result result;
+    auto collect{[&](const auto &operands) {
+      for (const auto &operand : operands) {
+        result = Combine(std::move(result), (*this)(operand));
+      }
+    }};
+    collect(condExpr.conditions());
+    collect(condExpr.values());
+    return result;
+  }
+
private:
static bool IsBindingUsedAsProcedure(const Expr<SomeType> &expr) {
if (const auto *pd{std::get_if<ProcedureDesignator>(&expr.u)}) {
diff --git a/flang/lib/Evaluate/expression.cpp b/flang/lib/Evaluate/expression.cpp
index 759fe5bc71b69..e5289d14d63fc 100644
--- a/flang/lib/Evaluate/expression.cpp
+++ b/flang/lib/Evaluate/expression.cpp
@@ -64,6 +64,34 @@ Expr<Type<TypeCategory::Character, KIND>>::LEN() const {
}
return std::nullopt;
},
+      [](const ConditionalExpr<Result> &c) -> T {
+        // Per F2023 10.1.4, a character conditional expression has the
+        // length of whichever value is selected at run time.  When every
+        // value has the same length expression, that is the length;
+        // otherwise the length is a conditional expression over the same
+        // conditions.  If any value's length is unknown, so is this one.
+        std::vector<Expr<SubscriptInteger>> lengths;
+        for (const auto &value : c.values()) {
+          if (auto len{value.LEN()}) {
+            lengths.emplace_back(std::move(*len));
+          } else {
+            return std::nullopt;
+          }
+        }
+        if (lengths.empty()) {
+          return std::nullopt;
+        }
+        bool allSame{true};
+        for (const auto &len : lengths) {
+          allSame = allSame && len == lengths.front();
+        }
+        if (allSame) {
+          return std::move(lengths.front());
+        }
+        return Expr<SubscriptInteger>{ConditionalExpr<SubscriptInteger>{
+            std::vector<Expr<SomeLogical>>(c.conditions()),
+            std::move(lengths)}};
+      },
[](const Designator<Result> &dr) { return dr.LEN(); },
[](const FunctionRef<Result> &fr) { return fr.LEN(); },
[](const SetLength<KIND> &x) -> T { return x.right(); },
@@ -141,6 +169,12 @@ template <typename A> bool Extremum<A>::operator==(const Extremum &that) const {
return ordering == that.ordering && Base::operator==(that);
}
+// Equal when all conditions and all values compare equal, pairwise in order.
+template <typename A>
+bool ConditionalExpr<A>::operator==(const ConditionalExpr &that) const {
+  return conditions_ == that.conditions_ && values_ == that.values_;
+}
+
template <int KIND>
bool LogicalOperation<KIND>::operator==(const LogicalOperation &that) const {
return logicalOperator == that.logicalOperator && Base::operator==(that);
diff --git a/flang/lib/Evaluate/formatting.cpp b/flang/lib/Evaluate/formatting.cpp
index 5632015857ab3..4c1002cf1cfc5 100644
--- a/flang/lib/Evaluate/formatting.cpp
+++ b/flang/lib/Evaluate/formatting.cpp
@@ -587,6 +587,23 @@ llvm::raw_ostream &ArrayConstructor<SomeDerived>::AsFortran(
return o << ']';
}
+template <typename T>
+llvm::raw_ostream &ConditionalExpr<T>::AsFortran(llvm::raw_ostream &o) const {
+  // Prints "(c1 ? v1 : c2 ? v2 : ... : vElse)": each condition is followed by
+  // the value paired with it, and the last value is the "else" value.
+  o << '(';
+  auto valueIter{values_.begin()};
+  for (const auto &condition : conditions_) {
+    condition.AsFortran(o);
+    o << " ? ";
+    valueIter->AsFortran(o);
+    ++valueIter;
+    o << " : ";
+  }
+  values_.back().AsFortran(o);
+  return o << ')';
+}
+
template <typename RESULT>
std::string ExpressionBase<RESULT>::AsFortran() const {
std::string buf;
diff --git a/flang/lib/Evaluate/shape.cpp b/flang/lib/Evaluate/shape.cpp
index 27913c3559c71..e37213041a7e4 100644
--- a/flang/lib/Evaluate/shape.cpp
+++ b/flang/lib/Evaluate/shape.cpp
@@ -221,14 +221,17 @@ ConstantSubscript GetSize(const ConstantSubscripts &shape) {
return size;
}
bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
struct MyVisitor : public AnyTraverse<MyVisitor> {
using Base = AnyTraverse<MyVisitor>;
MyVisitor() : Base{*this} {}
using Base::operator();
bool operator()(const ImpliedDoIndex &) { return true; }
+    // No ConditionalExpr<T> overload is needed (a local class cannot declare
+    // member templates anyway): Traverse visits every condition and value of
+    // a conditional expression, so implied DO indices inside them are found.
};
return MyVisitor{}(expr);
}
// Determines lower bound on a dimension. This can be other than 1 only
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 9b7d4c758769e..4570d55d7e73d 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1085,6 +1085,19 @@ struct CollectSymbolsHelper
semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
return {symbol};
}
+  template <typename T>
+  semantics::UnorderedSymbolSet operator()(const ConditionalExpr<T> &x) {
+    // Union of the symbols found in every condition and every value.
+    semantics::UnorderedSymbolSet symbols;
+    auto gather{[&](const auto &operands) {
+      for (const auto &operand : operands) {
+        symbols.merge((*this)(operand));
+      }
+    }};
+    gather(x.conditions());
+    gather(x.values());
+    return symbols;
+  }
};
template <typename A> semantics::UnorderedSymbolSet CollectSymbols(const A &x) {
return CollectSymbolsHelper{}(x);
@@ -1118,6 +1131,19 @@ struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper,
semantics::UnorderedSymbolSet operator()(const ProcedureRef &) const {
return {};
}
+  template <typename T>
+  semantics::UnorderedSymbolSet operator()(const ConditionalExpr<T> &x) {
+    // Union of the CUDA symbols found in every condition and every value.
+    semantics::UnorderedSymbolSet symbols;
+    auto gather{[&](const auto &operands) {
+      for (const auto &operand : operands) {
+        symbols.merge((*this)(operand));
+      }
+    }};
+    gather(x.conditions());
+    gather(x.values());
+    return symbols;
+  }
};
template <typename A>
semantics::UnorderedSymbolSet CollectCudaSymbols(const A &x) {
@@ -1185,6 +1209,18 @@ struct HasVectorSubscriptHelper
bool operator()(const ProcedureRef &) const {
return false; // don't descend into function call arguments
}
+  template <typename T> bool operator()(const ConditionalExpr<T> &x) {
+    // True when any condition or any value has a vector subscript.
+    auto anyOperand{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if ((*this)(operand)) {
+          return true;
+        }
+      }
+      return false;
+    }};
+    return anyOperand(x.conditions()) || anyOperand(x.values());
+  }
};
bool HasVectorSubscript(const Expr<SomeType> &expr) {
@@ -1211,6 +1247,18 @@ struct HasConstantHelper : public AnyTraverse<HasConstantHelper, bool,
}
// Only look for constant not in subscript.
bool operator()(const Subscript &) const { return false; }
+  template <typename T> bool operator()(const ConditionalExpr<T> &x) {
+    // True when any condition or any value has a constant.
+    auto anyOperand{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if ((*this)(operand)) {
+          return true;
+        }
+      }
+      return false;
+    }};
+    return anyOperand(x.conditions()) || anyOperand(x.values());
+  }
};
bool HasConstant(const Expr<SomeType> &expr) {
@@ -1225,6 +1277,19 @@ struct HasStructureComponentHelper
using Base::operator();
bool operator()(const Component &) const { return true; }
+
+  template <typename T> bool operator()(const ConditionalExpr<T> &x) {
+    // True when any condition or any value has a structure component.
+    auto anyOperand{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if ((*this)(operand)) {
+          return true;
+        }
+      }
+      return false;
+    }};
+    return anyOperand(x.conditions()) || anyOperand(x.values());
+  }
};
bool HasStructureComponent(const Expr<SomeType> &expr) {
@@ -1291,6 +1356,22 @@ class FindImpureCallHelper
return call.proc().GetName();
}
+  template <typename T> Result operator()(const ConditionalExpr<T> &x) const {
+    // Name of the first impure call in a condition, else in a value.
+    auto firstImpure{[this](const auto &operands) -> Result {
+      for (const auto &operand : operands) {
+        if (auto name{(*this)(operand)}) {
+          return name;
+        }
+      }
+      return std::nullopt;
+    }};
+    if (auto name{firstImpure(x.conditions())}) {
+      return name;
+    }
+    return firstImpure(x.values());
+  }
+
private:
FoldingContext &context_;
};
@@ -1726,6 +1808,20 @@ struct ArgumentExtractor
return {operation::OperationCode(x), {AsSomeExpr(x)}};
}
+  template <typename T> Result operator()(const ConditionalExpr<T> &x) const {
+    // A conditional expression is a top-level operation whose immediate
+    // operands are all of its conditions followed by all of its values.
+    Arguments operands;
+    auto append{[&](const auto &exprs) {
+      for (const auto &expr : exprs) {
+        operands.push_back(AsSomeExpr(expr));
+      }
+    }};
+    append(x.conditions());
+    append(x.values());
+    return {Operator::Conditional, std::move(operands)};
+  }
+
template <typename... Rs>
Result Combine(Result &&result, Rs &&...results) const {
// There shouldn't be any combining needed, since we're stopping the
@@ -1763,6 +1859,8 @@ std::string operation::ToString(operation::Operator op) {
return "ASSOCIATED";
case Operator::Call:
return "function-call";
+  case Operator::Conditional:
+    return "conditional";
case Operator::Constant:
return "constant";
case Operator::Convert:
@@ -1891,6 +1987,15 @@ struct ConvertCollector
}
}
+ template <typename T> Result operator()(const ConditionalExpr<T> &x) const {
+ // For conditional expressions, collect conversions from all values only
+ Result result;
+ for (const auto &val : x.values()) {
+ result = Combine(std::move(result), (*this)(val));
+ }
+ return result;
+ }
+
template <typename... Rs>
Result Combine(Result &&result, Rs &&...results) const {
Result v(std::move(result));
@@ -1983,6 +2088,22 @@ struct VariableFinder : public evaluate::AnyTraverse<VariableFinder> {
return evaluate::AsGenericExpr(common::Clone(x)) == var;
}
+  template <typename T>
+  bool operator()(const evaluate::ConditionalExpr<T> &x) const {
+    // True when `var` occurs in any condition or in any value (the else
+    // value included); conditions are examined first and the search stops at
+    // the first hit.
+    auto occursIn{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if ((*this)(operand)) {
+          return true;
+        }
+      }
+      return false;
+    }};
+    return occursIn(x.conditions()) || occursIn(x.values());
+  }
+
private:
const SomeExpr &var;
};
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index a7e0239d335fd..32cd710e9b5b4 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -927,6 +927,11 @@ class ScalarExprLowering {
return builder.createNullConstant(getLoc());
}
+  /// Conditional expressions are not handled by this (non-HLFIR) lowering
+  /// path; reaching this overload is reported as a fatal internal error.
+  template <typename T>
+  ExtValue genval(const Fortran::evaluate::ConditionalExpr<T> &) {
+    fir::emitFatalError(getLoc(), "ConditionalExpr should be lowered to HLFIR");
+  }
+
static bool
isDerivedTypeWithLenParameters(const Fortran::semantics::Symbol &sym) {
if (const Fortran::semantics::DeclTypeSpec *declTy = sym.GetType())
@@ -5366,6 +5371,11 @@ class ArrayExprLowering {
};
}
+  /// Array-context counterpart of the scalar stub: this lowering path does
+  /// not support conditional expressions and aborts with a fatal error.
+  template <typename T>
+  CC genarr(const Fortran::evaluate::ConditionalExpr<T> &) {
+    fir::emitFatalError(getLoc(), "ConditionalExpr should be lowered to HLFIR");
+  }
+
template <typename T>
CC genarr(const Fortran::evaluate::Constant<T> &x) {
if (x.Rank() == 0)
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 0c015bc9a2f1b..7ddd09e59c262 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1788,6 +1788,12 @@ class HlfirBuilder {
TODO(getLoc(), "lowering type parameter inquiry to HLFIR");
}
+  /// HLFIR lowering of conditional expressions is not implemented yet;
+  /// report a TODO at the current location.
+  template <typename A>
+  hlfir::EntityWithAttributes
+  gen(const Fortran::evaluate::ConditionalExpr<A> &) {
+    TODO(getLoc(), "lowering conditional expression to HLFIR");
+  }
+
hlfir::EntityWithAttributes
gen(const Fortran::evaluate::DescriptorInquiry &desc) {
mlir::Location loc = getLoc();
diff --git a/flang/lib/Lower/IterationSpace.cpp b/flang/lib/Lower/IterationSpace.cpp
index 203fec508f795..1f650a9fa5412 100644
--- a/flang/lib/Lower/IterationSpace.cpp
+++ b/flang/lib/Lower/IterationSpace.cpp
@@ -212,6 +212,17 @@ class ArrayBaseFinder {
(void)find(op.right());
return false;
}
+  template <typename T>
+  RT find(const Fortran::evaluate::ConditionalExpr<T> &x) {
+    // Visit every condition and every value so that all array bases they
+    // reference are recorded; the per-operand results are discarded.
+    for (const auto &condition : x.conditions())
+      (void)find(condition);
+    for (const auto &value : x.values())
+      (void)find(value);
+    return {};
+  }
RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
(void)find(x.u);
return {};
diff --git a/flang/lib/Lower/Support/Utils.cpp b/flang/lib/Lower/Support/Utils.cpp
index 384636a659875..230587f3b951c 100644
--- a/flang/lib/Lower/Support/Utils.cpp
+++ b/flang/lib/Lower/Support/Utils.cpp
@@ -158,6 +158,18 @@ class HashEvaluateExpr {
static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) +
static_cast<unsigned>(x.ordering) * 7u;
}
+  template <typename T>
+  static unsigned getHashValue(const Fortran::evaluate::ConditionalExpr<T> &x) {
+    // Conditions and values are folded into separate accumulators (seeded
+    // with 1 and 3) that are then mixed; unsigned arithmetic wraps.
+    unsigned condHash{1u};
+    for (const auto &condition : x.conditions())
+      condHash -= getHashValue(condition);
+    unsigned valueHash{3u};
+    for (const auto &value : x.values())
+      valueHash += getHashValue(value);
+    return condHash * 151u - valueHash;
+  }
template <Fortran::common::TypeCategory TC, int KIND>
static unsigned getHashValue(
const Fortran::evaluate::RealToIntPower<Fortran::evaluate::Type<TC, KIND>>
@@ -416,6 +428,26 @@ class IsEqualEvaluateExpr {
const Fortran::evaluate::Extremum<A> &y) {
return isBinaryEqual(x, y);
}
+  template <typename T>
+  static bool isEqual(const Fortran::evaluate::ConditionalExpr<T> &x,
+                      const Fortran::evaluate::ConditionalExpr<T> &y) {
+    // Structural equality: the same number of conditions and values, and
+    // each condition/value equal to its counterpart at the same position.
+    const auto &xConds{x.conditions()};
+    const auto &yConds{y.conditions()};
+    const auto &xVals{x.values()};
+    const auto &yVals{y.values()};
+    if (xConds.size() != yConds.size() || xVals.size() != yVals.size())
+      return false;
+    for (size_t idx{0}; idx < xConds.size(); ++idx)
+      if (!isEqual(xConds[idx], yConds[idx]))
+        return false;
+    for (size_t idx{0}; idx < xVals.size(); ++idx)
+      if (!isEqual(xVals[idx], yVals[idx]))
+        return false;
+    return true;
+  }
template <typename A>
static bool isEqual(const Fortran::evaluate::RealToIntPower<A> &x,
const Fortran::evaluate::RealToIntPower<A> &y) {
diff --git a/flang/lib/Parser/basic-parsers.h b/flang/lib/Parser/basic-parsers.h
index eeb59a830fc0c..3fde9f83b935e 100644
--- a/flang/lib/Parser/basic-parsers.h
+++ b/flang/lib/Parser/basic-parsers.h
@@ -835,6 +835,38 @@ struct NextCh {
constexpr NextCh nextCh;
+// Lookahead helper for conditional expressions: succeeds only when the input
+// starts with '(' and a '?' appears at parenthesis nesting level 1 before the
+// matching ')'. Rejecting other parenthesized primaries cheaply avoids
+// exponential backtracking when parsing deeply nested parentheses that are
+// not conditional expressions.
+//
+// Characters inside quoted character literals are skipped, so a literal such
+// as ")" , "(" or "?" cannot end the scan early or fake a match. The scan also
+// stops at a newline, which ends the statement in the cooked character stream,
+// so unbalanced parentheses never make it run on into later statements.
+struct ConditionalExprLookahead {
+  using resultType = Success;
+  constexpr ConditionalExprLookahead() {}
+  std::optional<Success> Parse(ParseState &state) const {
+    std::optional<const char *> at{state.PeekAtNextChar()};
+    if (!at || **at != '(') {
+      return std::nullopt;
+    }
+    const char *const start{*at};
+    const char *const limit{start + state.BytesRemaining()};
+    int nestLevel{0};
+    char quote{'\0'}; // delimiter of the character literal being skipped
+    for (const char *p{start}; p < limit; ++p) {
+      const char ch{*p};
+      if (ch == '\n') {
+        break; // end of statement without reaching the matching ')'
+      } else if (quote != '\0') {
+        if (ch == quote) {
+          // Closes the literal; a doubled delimiter ('' or "") simply
+          // reopens it on the next character.
+          quote = '\0';
+        }
+      } else if (ch == '\'' || ch == '"') {
+        quote = ch;
+      } else if (ch == '(') {
+        ++nestLevel;
+      } else if (ch == ')') {
+        if (--nestLevel == 0) {
+          return std::nullopt; // closed the opening '(' with no '?'
+        }
+      } else if (ch == '?' && nestLevel == 1) {
+        return Success{};
+      }
+    }
+    return std::nullopt;
+  }
+};
+constexpr ConditionalExprLookahead conditionalExprLookahead;
+
// If a is a parser for some nonstandard language feature LF, extension<LF>(a)
// is a parser that optionally enabled, sets a strict conformance violation
// flag, and may emit a warning message, if those are enabled.
diff --git a/flang/lib/Parser/expr-parsers.cpp b/flang/lib/Parser/expr-parsers.cpp
index b6832a7999c5b..bc6bf609419b8 100644
--- a/flang/lib/Parser/expr-parsers.cpp
+++ b/flang/lib/Parser/expr-parsers.cpp
@@ -70,6 +70,7 @@ TYPE_PARSER(construct<AcImpliedDoControl>(
constexpr auto primary{instrumented("primary"_en_US,
first(construct<Expr>(indirect(charLiteralConstantSubstring)),
construct<Expr>(literalConstant),
+ construct<Expr>(Parser<ConditionalExpr>{}),
construct<Expr>(construct<Expr::Parentheses>("(" >>
expr / !","_tok / recovery(")"_tok, SkipPastNested<'(', ')'>{}))),
construct<Expr>(indirect(functionReference) / !"("_tok / !"%"_tok),
@@ -94,6 +95,17 @@ constexpr auto level1Expr{sourced(
primary || // must come before define op to resolve .TRUE._8 ambiguity
construct<Expr>(construct<Expr::DefinedUnary>(definedOpName, primary)))};
+// F2023: R1002 conditional-expr ->
+// ( scalar-logical-expr ? expr
+// [ : scalar-logical-expr ? expr ]...
+// : expr )
+// conditionalExprLookahead first rejects parenthesized input with no '?' at
+// nesting level 1. Each Branch is "condition ? value" followed by ':'; some()
+// takes one or more such branches, and the trailing expr is the else value.
+// A conditional is itself a parenthesized primary, so a conditional nested in
+// a condition or value carries its own parentheses.
+TYPE_PARSER(conditionalExprLookahead >>
+ parenthesized(construct<ConditionalExpr>(
+ some(construct<ConditionalExpr::Branch>(
+ scalarLogicalExpr / "?", indirect(expr)) /
+ ":"),
+ indirect(expr))))
+
// R1004 mult-operand -> level-1-expr [power-op mult-operand]
// R1007 power-op -> **
// Exponentiation (**) is Fortran's only right-associative binary operation.
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 9d01bb74d70d3..e87b4c1f9b11d 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -900,6 +900,18 @@ class UnparseVisitor {
void Unparse(const Expr::OR &x) { Walk(x.t, ".OR."); }
void Unparse(const Expr::EQV &x) { Walk(x.t, ".EQV."); }
void Unparse(const Expr::NEQV &x) { Walk(x.t, ".NEQV."); }
+  void Unparse(const ConditionalExpr &x) { // F2023: R1002
+    // Prints "( c1 ? v1 : c2 ? v2 : ... : else )", with a space on each side
+    // of every '?' and ':' and inside the parentheses.
+    Put("( ");
+    for (const auto &branch :
+        std::get<std::list<ConditionalExpr::Branch>>(x.t)) {
+      const auto &[condition, value]{branch.t};
+      Walk(condition);
+      Put(" ? ");
+      Walk(value);
+      Put(" : ");
+    }
+    Walk(std::get<common::Indirection<Expr>>(x.t));
+    Put(" )");
+  }
void Unparse(const Expr::ComplexConstructor &x) {
Put('('), Walk(x.t, ","), Put(')');
}
diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index 13c523da13c25..b69845fbb6be2 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -111,6 +111,21 @@ struct DeviceExprChecker
return parser::MessageFormattedText(
"'%s' may not be called in device code"_err_en_US, x.GetName());
}
+  template <typename T>
+  Result operator()(const evaluate::ConditionalExpr<T> &x) const {
+    // Report the first device-code violation found in the conditions, or
+    // failing that in the values; an empty Result means none was found.
+    auto firstViolation{[this](const auto &operands) -> Result {
+      for (const auto &operand : operands) {
+        if (Result msg{(*this)(operand)}) {
+          return msg;
+        }
+      }
+      return Result{};
+    }};
+    if (Result msg{firstViolation(x.conditions())}) {
+      return msg;
+    }
+    return firstViolation(x.values());
+  }
SemanticsContext &context_;
};
@@ -150,6 +165,21 @@ struct FindHostArray
}
return nullptr;
}
+  template <typename T>
+  Result operator()(const evaluate::ConditionalExpr<T> &x) const {
+    // Return the first host array referenced by a condition, else by a
+    // value; nullptr when neither references one.
+    auto firstHostArray{[this](const auto &operands) -> Result {
+      for (const auto &operand : operands) {
+        if (Result hostArray{(*this)(operand)}) {
+          return hostArray;
+        }
+      }
+      return nullptr;
+    }};
+    Result found{firstHostArray(x.conditions())};
+    return found ? found : firstHostArray(x.values());
+  }
};
template <typename A>
diff --git a/flang/lib/Semantics/check-data.cpp b/flang/lib/Semantics/check-data.cpp
index 9dbbc163d85b3..c93711c8fc313 100644
--- a/flang/lib/Semantics/check-data.cpp
+++ b/flang/lib/Semantics/check-data.cpp
@@ -174,6 +174,21 @@ class DataVarChecker : public evaluate::AllTraverse<DataVarChecker, true> {
}
}
+  template <typename T> bool operator()(const evaluate::ConditionalExpr<T> &x) {
+    // Valid only if every condition and every value is valid; checking
+    // stops at the first failure (conditions are examined first).
+    auto allValid{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if (!(*this)(operand)) {
+          return false;
+        }
+      }
+      return true;
+    }};
+    return allValid(x.conditions()) && allValid(x.values());
+  }
+
private:
bool CheckSubscriptExpr(
const std::optional<evaluate::IndirectSubscriptIntegerExpr> &x) const {
diff --git a/flang/lib/Semantics/check-do-forall.cpp b/flang/lib/Semantics/check-do-forall.cpp
index bf92d920f282e..beb0f777ccf32 100644
--- a/flang/lib/Semantics/check-do-forall.cpp
+++ b/flang/lib/Semantics/check-do-forall.cpp
@@ -1143,6 +1143,17 @@ struct CollectActualArgumentsHelper
return Combine(ActualArgumentSet{arg},
CollectActualArgumentsHelper{}(arg.UnwrapExpr()));
}
+  template <typename T>
+  ActualArgumentSet operator()(const evaluate::ConditionalExpr<T> &x) const {
+    // Union of the actual arguments found in every condition and value.
+    ActualArgumentSet collected;
+    auto addFrom{[&](const auto &operands) {
+      for (const auto &operand : operands) {
+        collected = Combine(std::move(collected), (*this)(operand));
+      }
+    }};
+    addFrom(x.conditions());
+    addFrom(x.values());
+    return collected;
+  }
};
template <typename A> ActualArgumentSet CollectActualArguments(const A &x) {
diff --git a/flang/lib/Semantics/definable.cpp b/flang/lib/Semantics/definable.cpp
index de16422b89abd..581c1796b692e 100644
--- a/flang/lib/Semantics/definable.cpp
+++ b/flang/lib/Semantics/definable.cpp
@@ -305,6 +305,20 @@ class DuplicatedSubscriptFinder
}
return anyVector ? false : (*this)(aRef.base());
}
+  template <typename T> bool operator()(const evaluate::ConditionalExpr<T> &x) {
+    // True as soon as any condition, or else any value, reports a
+    // duplicated subscript.
+    auto anyDuplicated{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if ((*this)(operand)) {
+          return true;
+        }
+      }
+      return false;
+    }};
+    return anyDuplicated(x.conditions()) || anyDuplicated(x.values());
+  }
private:
evaluate::FoldingContext &foldingContext_;
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 457c5a3594f6d..a99fd6b0a94b3 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -3880,6 +3880,216 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::PercentLoc &x) {
return MakeFunctionRef(loc, ActualArguments{std::move(*arg)});
}
+// Detection trait: true for types exposing a nested `Result` typedef (the
+// typed Expr<T> wrappers), false for anything else.
+template <typename T, typename = void>
+struct HasResultType : std::bool_constant<false> {};
+template <typename T>
+struct HasResultType<T, std::void_t<typename T::Result>>
+    : std::bool_constant<true> {};
+
+MaybeExpr ExpressionAnalyzer::Analyze(const parser::ConditionalExpr &x) {
+  // F2023 R1002: ( cond ? value [ : cond ? value ]... : else-value ).
+  // Each condition must be a scalar LOGICAL expression, and every value (the
+  // branch values plus the trailing else value) must have the same declared
+  // type, kind, and rank (C1004). On success the result is a
+  // ConditionalExpr<T>, typed after the first value, holding N conditions and
+  // N+1 values. On the first violation an error is reported and std::nullopt
+  // is returned.
+  const auto &branches{
+      std::get<std::list<parser::ConditionalExpr::Branch>>(x.t)};
+  const auto &elseExpr{std::get<common::Indirection<parser::Expr>>(x.t)};
+  std::vector<MaybeExpr> conditions;
+  std::vector<MaybeExpr> values;
+  for (const auto &branch : branches) {
+    const auto &condition{std::get<parser::ScalarLogicalExpr>(branch.t)};
+    const auto &value{std::get<common::Indirection<parser::Expr>>(branch.t)};
+    MaybeExpr condExpr{Analyze(condition.thing.thing.value())};
+    if (!condExpr) {
+      return std::nullopt;
+    }
+    if (!std::get_if<Expr<SomeLogical>>(&condExpr->u)) {
+      if (const auto type{condExpr->GetType()}) {
+        Say("Condition in conditional expression must be LOGICAL; have %s"_err_en_US,
+            type->AsFortran());
+      } else {
+        Say("Condition in conditional expression must be LOGICAL"_err_en_US);
+      }
+      return std::nullopt;
+    }
+    if (condExpr->Rank() != 0) {
+      Say("Condition in conditional expression must be scalar; have rank %d"_err_en_US,
+          condExpr->Rank());
+      return std::nullopt;
+    }
+    conditions.push_back(std::move(condExpr));
+    MaybeExpr valExpr{Analyze(value.value())};
+    if (!valExpr) {
+      return std::nullopt;
+    }
+    if (semantics::IsAssumedRank(*valExpr)) {
+      Say("An assumed-rank dummy argument may not be used as a value in a conditional expression"_err_en_US);
+      return std::nullopt;
+    }
+    values.push_back(std::move(valExpr));
+  }
+
+  // Analyze the else expression; it becomes the last value.
+  MaybeExpr elseValue{Analyze(elseExpr.value())};
+  if (!elseValue) {
+    return std::nullopt;
+  }
+  if (semantics::IsAssumedRank(*elseValue)) {
+    Say("An assumed-rank dummy argument may not be used as a value in a conditional expression"_err_en_US);
+    return std::nullopt;
+  }
+  values.push_back(std::move(elseValue));
+  CHECK(values.size() == conditions.size() + 1 &&
+      "values must have exactly one more element than conditions");
+
+  // F2023 C1004: Each expr shall have the same declared type, kind type
+  // parameters, and rank. Reject typeless expressions (BOZ and NULL).
+  for (const auto &value : values) {
+    // BOZ arrays are auto-converted in array constructors, but bare BOZ are not
+    // allowed
+    if (std::holds_alternative<BOZLiteralConstant>(value->u)) {
+      Say("BOZ literal constant in conditional expression must have explicit "
+          "type "
+          "(e.g., INT(z'FF'), REAL(z'3F800000'))"_err_en_US);
+      return std::nullopt;
+    }
+    if (std::holds_alternative<evaluate::NullPointer>(value->u)) {
+      Say("NULL() not allowed in conditional expression (expressions must have declared type)"_err_en_US);
+      return std::nullopt;
+    }
+  }
+
+  // Determine result type from first value
+  const std::optional<DynamicType> resultType = values[0]->GetType();
+  if (!resultType) {
+    Say("Cannot determine type of conditional expression"_err_en_US);
+    return std::nullopt;
+  }
+
+  // Check that all values have the exact same type and kind (no promotion
+  // allowed)
+  const TypeCategory resultCategory{resultType->category()};
+  const int resultKind{
+      resultCategory != TypeCategory::Derived ? resultType->kind() : 0};
+  const int resultRank = values[0]->Rank();
+  // Check for polymorphic types (not yet supported in lowering)
+  if (resultCategory == TypeCategory::Derived && resultType->IsPolymorphic()) {
+    Say("Conditional expressions with polymorphic types (CLASS) are not yet supported"_err_en_US);
+    return std::nullopt;
+  }
+  for (const auto &value : values) {
+    // Check for coindexed objects
+    if (const auto dataRef{ExtractDataRef(value)}) {
+      if (ExtractCoarrayRef(*dataRef)) {
+        Say("Conditional expression values may not be coindexed"_err_en_US);
+        return std::nullopt;
+      }
+    }
+    const auto valueType{value->GetType()};
+    if (!valueType) {
+      Say("Cannot determine type of expression in conditional expression"_err_en_US);
+      return std::nullopt;
+    }
+    const TypeCategory valueCategory{valueType->category()};
+    // Every value must be checked for polymorphism, not just the first:
+    // otherwise a CLASS value in a later branch reaches lowering, and for
+    // CLASS(*) (which presumably has no derived type spec) the
+    // GetDerivedTypeSpec() call below would have nothing to return.
+    if (valueCategory == TypeCategory::Derived && valueType->IsPolymorphic()) {
+      Say("Conditional expressions with polymorphic types (CLASS) are not yet supported"_err_en_US);
+      return std::nullopt;
+    }
+    const int valueKind{
+        valueCategory != TypeCategory::Derived ? valueType->kind() : 0};
+    if (resultCategory != valueCategory ||
+        (resultCategory != TypeCategory::Derived && resultKind != valueKind)) {
+      Say("All values in conditional expression must have the same type and kind; have %s and %s"_err_en_US,
+          resultType->AsFortran(), valueType->AsFortran());
+      return std::nullopt;
+    }
+    // For derived types, check they are the exact same type (not just
+    // compatible)
+    if (resultCategory == TypeCategory::Derived) {
+      if (&resultType->GetDerivedTypeSpec().typeSymbol() !=
+          &valueType->GetDerivedTypeSpec().typeSymbol()) {
+        Say("All values in conditional expression must be the same derived type; have %s and %s"_err_en_US,
+            resultType->AsFortran(), valueType->AsFortran());
+        return std::nullopt;
+      }
+    }
+    const int valueRank{value->Rank()};
+    if (resultRank != valueRank) {
+      Say("All values in conditional expression must have the same rank; have rank %d and %d"_err_en_US,
+          resultRank, valueRank);
+      return std::nullopt;
+    }
+  }
+
+  // Every condition was verified above to hold an Expr<SomeLogical>; unwrap
+  // them once for whichever typed ConditionalExpr is built below.
+  std::vector<Expr<SomeLogical>> typedConditions;
+  typedConditions.reserve(conditions.size());
+  for (auto &cond : conditions) {
+    auto *logicalExpr{std::get_if<Expr<SomeLogical>>(&cond->u)};
+    CHECK(logicalExpr && "Condition should be SomeLogical");
+    typedConditions.emplace_back(std::move(*logicalExpr));
+  }
+
+  // Dispatch on the runtime type of values[0] to build the appropriately
+  // typed ConditionalExpr, with nested visitation to unwrap category->specific
+  // types.
+  return common::visit(
+      common::visitors{
+          // Takes an rvalue reference: with `const &` the generic `auto &&`
+          // lambda below would be the better match for the moved variant
+          // alternative and this diagnostic would never be the one emitted.
+          [&](BOZLiteralConstant &&) -> MaybeExpr {
+            DIE("BOZ literal should have been eliminated by type validation");
+          },
+          [&](Expr<SomeDerived> &&derivedExpr) -> MaybeExpr {
+            std::vector<Expr<SomeDerived>> typedValues;
+            typedValues.reserve(values.size());
+            // Use the moved-in first value directly, then process remaining
+            // values
+            typedValues.emplace_back(std::move(derivedExpr));
+            for (auto &val : llvm::drop_begin(values, 1)) {
+              auto *derivedVal{std::get_if<Expr<SomeDerived>>(&val->u)};
+              CHECK(derivedVal && "Value should be SomeDerived");
+              typedValues.emplace_back(std::move(*derivedVal));
+            }
+            return AsGenericExpr(
+                Expr<SomeDerived>{evaluate::ConditionalExpr<SomeDerived>{
+                    std::move(typedConditions), std::move(typedValues)}});
+          },
+          [&](auto &&categoryExpr) -> MaybeExpr {
+            using CategoryType = std::decay_t<decltype(categoryExpr)>;
+            if constexpr (std::is_same_v<CategoryType, TypelessExpression> ||
+                std::is_same_v<CategoryType, Expr<SomeDerived>> ||
+                std::is_same_v<CategoryType, NullPointer> ||
+                std::is_same_v<CategoryType, ProcedureDesignator> ||
+                std::is_same_v<CategoryType, ProcedureRef>) {
+              DIE("Invalid expression type in conditional expression");
+            } else if constexpr (!HasResultType<CategoryType>::value) {
+              DIE("Unexpected bare constant type in conditional expression");
+            } else {
+              return common::visit(
+                  [&](auto &&specificExpr) -> MaybeExpr {
+                    using SpecificType = std::decay_t<decltype(specificExpr)>;
+                    using T = typename SpecificType::Result;
+                    std::vector<Expr<T>> typedValues;
+                    typedValues.reserve(values.size());
+                    // Use the moved-in first value directly, then process
+                    // remaining values
+                    typedValues.emplace_back(std::move(specificExpr));
+                    for (auto &val : llvm::drop_begin(values, 1)) {
+                      auto *catExpr{std::get_if<CategoryType>(&val->u)};
+                      CHECK(catExpr && "Value should be CategoryType");
+                      auto *specificVal{std::get_if<Expr<T>>(&catExpr->u)};
+                      CHECK(specificVal && "Value should be Expr<T>");
+                      typedValues.emplace_back(std::move(*specificVal));
+                    }
+                    return AsGenericExpr(CategoryType{Expr<T>{
+                        evaluate::ConditionalExpr<T>{std::move(typedConditions),
+                            std::move(typedValues)}}});
+                  },
+                  categoryExpr.u);
+            }
+          },
+      },
+      std::move(values[0]->u));
+}
+
MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::DefinedUnary &x) {
const auto &name{std::get<parser::DefinedOpName>(x.t).v};
ArgumentAnalyzer analyzer{*this, name.source};
@@ -5144,6 +5354,12 @@ std::optional<ActualArgument> ArgumentAnalyzer::AnalyzeExpr(
}
context_.SayAt(expr.source,
"TYPE(*) dummy argument may only be used as an actual argument"_err_en_US);
+ } else if (isProcedureCall_ &&
+ std::holds_alternative<parser::ConditionalExpr>(expr.u)) {
+ // Check parse tree before analysis to avoid wasted work
+ context_.SayAt(expr.source,
+ "Conditional expressions are not yet supported as actual arguments"_err_en_US);
+ return std::nullopt;
} else if (MaybeExpr argExpr{AnalyzeExprOrWholeAssumedSizeArray(expr)}) {
if (isProcedureCall_ || !IsProcedureDesignator(*argExpr)) {
// Pad Hollerith actual argument with spaces up to a multiple of 8
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index b553fe874a378..326f79b86a694 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -298,6 +298,12 @@ struct LogicalConstantVistor : public evaluate::Traverse<LogicalConstantVistor,
}
}
}
+
+  template <typename T>
+  Result operator()(const evaluate::ConditionalExpr<T> &) const {
+    // Deliberately opaque: a conditional expression is never reported as a
+    // constant logical value, whatever its operands are.
+    return std::nullopt;
+  }
};
} // namespace
@@ -390,6 +396,25 @@ struct DesignatorCollector : public evaluate::Traverse<DesignatorCollector,
(moveAppend(v, std::move(results)), ...);
return v;
}
+
+  template <typename T>
+  Result operator()(const evaluate::ConditionalExpr<T> &x) const {
+    // Gather the designators of every condition, then of every value, in
+    // order, moving them into a single result list.
+    Result designators;
+    auto appendFrom{[&](const auto &operands) {
+      for (const auto &operand : operands) {
+        for (auto &designator : (*this)(operand)) {
+          designators.push_back(std::move(designator));
+        }
+      }
+    }};
+    appendFrom(x.conditions());
+    appendFrom(x.values());
+    return designators;
+  }
};
std::vector<SomeExpr> GetAllDesignators(const SomeExpr &expr) {
diff --git a/flang/lib/Semantics/resolve-names-utils.cpp b/flang/lib/Semantics/resolve-names-utils.cpp
index ef34c89182f7f..e7ed72b2bfa34 100644
--- a/flang/lib/Semantics/resolve-names-utils.cpp
+++ b/flang/lib/Semantics/resolve-names-utils.cpp
@@ -694,6 +694,20 @@ class SymbolMapper : public evaluate::AnyTraverse<SymbolMapper, bool> {
}
return false;
}
+  template <typename T> bool operator()(const evaluate::ConditionalExpr<T> &x) {
+    // Map symbols in the conditions and then the values; as with the other
+    // AnyTraverse overloads, a true result from any operand ends the walk.
+    auto visitAll{[this](const auto &operands) {
+      for (const auto &operand : operands) {
+        if ((*this)(operand)) {
+          return true;
+        }
+      }
+      return false;
+    }};
+    return visitAll(x.conditions()) || visitAll(x.values());
+  }
void MapSymbolExprs(Symbol &);
Symbol *CopySymbol(const Symbol *);
diff --git a/flang/test/Parser/conditional-expr.f90 b/flang/test/Parser/conditional-expr.f90
new file mode 100644
index 0000000000000..1fddbe97a5ee7
--- /dev/null
+++ b/flang/test/Parser/conditional-expr.f90
@@ -0,0 +1,261 @@
+! RUN: %flang_fc1 -fdebug-unparse-no-sema %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -fdebug-dump-parse-tree-no-sema %s 2>&1 | FileCheck %s -check-prefix=TREE
+
+! Test parsing of conditional expressions (Fortran 2023 R1002)
+
+! Simple two-branch conditional
+! Expected shape: one Branch (condition ? value) followed by the else value.
+subroutine simple_conditional(x, y, z)
+ integer :: x, y, z
+ ! CHECK-LABEL: simple_conditional
+ ! CHECK: z = ( x>5 ? y : 10 )
+ ! TREE: ConditionalExpr
+ ! TREE-NEXT: Branch
+ ! TREE-NEXT: Scalar -> Logical -> Expr
+ ! TREE: Expr -> Designator -> DataRef -> Name = 'y'
+ ! TREE: Expr -> LiteralConstant -> IntLiteralConstant = '10'
+ z = (x > 5 ? y : 10)
+end subroutine
+
+! Three-branch conditional (multiple conditions)
+! Each extra "condition ? value :" segment adds another Branch to the same
+! conditional rather than nesting a new one.
+subroutine multi_branch_conditional(x, y, z)
+ integer :: x, y, z
+ ! CHECK-LABEL: multi_branch_conditional
+ ! CHECK: z = ( x>10 ? 100 : y<5 ? 50 : 0 )
+ ! TREE: ConditionalExpr
+ ! TREE-NEXT: Branch
+ ! TREE-NEXT: Scalar -> Logical -> Expr
+ ! TREE: Expr -> LiteralConstant -> IntLiteralConstant = '100'
+ ! TREE: Branch
+ ! TREE-NEXT: Scalar -> Logical -> Expr
+ ! TREE: Expr -> LiteralConstant -> IntLiteralConstant = '50'
+ ! TREE: Expr -> LiteralConstant -> IntLiteralConstant = '0'
+ z = (x > 10 ? 100 : y < 5 ? 50 : 0)
+end subroutine
+
+! Nested conditionals
+! A nested conditional carries its own parentheses and is parsed as a primary,
+! both in a value position and in a condition position.
+subroutine nested_conditionals(x, y, w, z, flag1, flag2)
+ integer :: x, y, w, z
+ logical :: flag1, flag2
+ ! CHECK-LABEL: nested_conditionals
+ ! Nested in value position
+ ! CHECK: z = ( flag1 ? ( x>y ? x : y ) : 0 )
+ ! TREE: ConditionalExpr
+ ! TREE-NEXT: Branch
+ ! TREE-NEXT: Scalar -> Logical -> Expr
+ ! TREE: Expr -> ConditionalExpr
+ ! TREE-NEXT: Branch
+ ! TREE-NEXT: Scalar -> Logical -> Expr
+ ! TREE: Expr -> Designator -> DataRef -> Name = 'x'
+ ! TREE: Expr -> Designator -> DataRef -> Name = 'y'
+ ! TREE: Expr -> LiteralConstant -> IntLiteralConstant = '0'
+ z = (flag1 ? (x > y ? x : y) : 0)
+ ! Nested in condition
+ ! CHECK: z = ( ( x>5 ? flag1 : flag2 ) ? y : 10 )
+ z = ((x > 5 ? flag1 : flag2) ? y : 10)
+ ! Multiple nested
+ ! CHECK: z = ( x>10 ? ( y>20 ? 1 : 2 ) : ( w>30 ? 3 : 4 ) )
+ z = (x > 10 ? (y > 20 ? 1 : 2) : (w > 30 ? 3 : 4))
+end subroutine
+
+! Basic type conditionals
+! The syntax is independent of the value type: REAL, LOGICAL, and CHARACTER.
+subroutine basic_types(x, a, b, c, flag1, str1)
+ integer :: x
+ real :: a, b, c
+ logical :: flag1
+ character(len=10) :: str1
+ ! CHECK-LABEL: basic_types
+ ! Real type
+ ! CHECK: c = ( a>b ? a : b )
+ c = (a > b ? a : b)
+ ! Logical type
+ ! CHECK: flag1 = ( x>5 ? .TRUE. : .FALSE. )
+ flag1 = (x > 5 ? .true. : .false.)
+ ! Character type
+ ! CHECK: str1 = ( flag1 ? "HELLO" : "WORLD" )
+ str1 = (flag1 ? "HELLO" : "WORLD")
+end subroutine
+
+! Complex expressions in conditions and branches
+! ("Complex" here means compound operator expressions, not the COMPLEX type.
+! Full expressions are accepted on either side of '?' with no extra parens.)
+subroutine complex_expressions(x, y, z, flag1)
+ integer :: x, y, z
+ logical :: flag1
+ ! CHECK-LABEL: complex_expressions
+ ! Complex expressions in branches
+ ! CHECK: z = ( x>y ? x*2+1 : y*3-2 )
+ z = (x > y ? x*2+1 : y*3-2)
+ ! Complex logical condition
+ ! CHECK: z = ( x>5.AND.y<10 ? x+y : x-y )
+ z = (x > 5 .and. y < 10 ? x+y : x-y)
+ ! Logical NOT
+ ! CHECK: z = ( .NOT.flag1 ? x : y )
+ z = (.not. flag1 ? x : y)
+ ! Comparison chains
+ ! CHECK: z = ( x>5.AND.x<10 ? x : 0 )
+ z = (x > 5 .and. x < 10 ? x : 0)
+ ! Parenthesized expressions in branches
+ ! CHECK: z = ( x>5 ? (y+z) : (y-z) )
+ z = (x > 5 ? (y+z) : (y-z))
+end subroutine
+
+! Many-branch conditionals
+! Long chains unparse as a single conditional with no inner parentheses.
+subroutine many_branches(x, z)
+ integer :: x, z
+ ! CHECK-LABEL: many_branches
+ ! Four branches
+ ! CHECK: z = ( x>10 ? 100 : x>5 ? 50 : x>0 ? 10 : 0 )
+ z = (x > 10 ? 100 : x > 5 ? 50 : x > 0 ? 10 : 0)
+ ! Five branches
+ ! CHECK: z = ( x>20 ? 1 : x>15 ? 2 : x>10 ? 3 : x>5 ? 4 : 5 )
+ z = (x > 20 ? 1 : x > 15 ? 2 : x > 10 ? 3 : x > 5 ? 4 : 5)
+end subroutine
+
+! Conditionals with arrays and functions
+! Array elements, intrinsic function references, and conditionals used as
+! array-constructor elements.
+subroutine arrays_and_functions(x, y, z, arr, flag1)
+ integer :: x, y, z, arr(5)
+ logical :: flag1
+ ! CHECK-LABEL: arrays_and_functions
+ ! Array element in conditional
+ ! CHECK: z = ( arr(1)>arr(2) ? arr(1) : arr(2) )
+ z = (arr(1) > arr(2) ? arr(1) : arr(2))
+ ! Function calls in conditional
+ ! CHECK: x = ( abs(y)>10 ? abs(y) : y )
+ x = (abs(y) > 10 ? abs(y) : y)
+ ! Array constructor elements
+ ! CHECK: arr(1:3) = [( flag1 ? x : y ), ( .NOT.flag1 ? x : y ), ( x>y ? x : y )]
+ arr(1:3) = [(flag1 ? x : y), (.not. flag1 ? x : y), (x > y ? x : y)]
+end subroutine
+
+! Literals in conditionals
+! Real literals, and a leading minus sign on a branch value.
+subroutine literals(x, z, a, c)
+ integer :: x, z
+ real :: a, c
+ ! CHECK-LABEL: literals
+ ! Real literals
+ ! CHECK: c = ( a>0.0 ? 1.5 : 2.5 )
+ c = (a > 0.0 ? 1.5 : 2.5)
+ ! Negative values
+ ! CHECK: z = ( x<0 ? -1 : 1 )
+ z = (x < 0 ? -1 : 1)
+end subroutine
+
+! Conditional assigned to a function result. Despite the function's name, no
+! specification expression is involved: this is an ordinary assignment.
+function spec_expr_conditional(n, flag) result(res)
+ integer, intent(in) :: n
+ logical, intent(in) :: flag
+ integer :: res
+ ! CHECK-LABEL: spec_expr_conditional
+ ! CHECK: res = ( flag ? n*2 : n )
+ res = (flag ? n*2 : n)
+end function
+
+! Conditional with different integer kinds
+! (each conditional uses a single kind, 4 or 8; kinds are never mixed here)
+subroutine integer_kinds(cond)
+ integer(kind=4) :: i4a, i4b, i4c
+ integer(kind=8) :: i8a, i8b, i8c
+ logical :: cond
+ ! CHECK-LABEL: integer_kinds
+ ! CHECK: i4c = ( cond ? i4a : i4b )
+ i4c = (cond ? i4a : i4b)
+ ! CHECK: i8c = ( cond ? i8a : i8b )
+ i8c = (cond ? i8a : i8b)
+end subroutine
+
+! Conditional with different real kinds
+! (each conditional uses a single kind, 4 or 8; kinds are never mixed here)
+subroutine real_kinds(cond)
+ real(kind=4) :: r4a, r4b, r4c
+ real(kind=8) :: r8a, r8b, r8c
+ logical :: cond
+ ! CHECK-LABEL: real_kinds
+ ! CHECK: r4c = ( cond ? r4a : r4b )
+ r4c = (cond ? r4a : r4b)
+ ! CHECK: r8c = ( cond ? r8a : r8b )
+ r8c = (cond ? r8a : r8b)
+end subroutine
+
+! Conditional in various statement contexts
+! A parenthesized conditional can appear wherever a primary is allowed.
+subroutine statement_contexts(flag)
+ integer :: x, y, arr(10)
+ logical :: flag
+ ! CHECK-LABEL: statement_contexts
+ ! In array constructor
+ ! CHECK: arr(1:3) = [1, ( flag ? x : y ), 3]
+ arr(1:3) = [1, (flag ? x : y), 3]
+ ! In if statement condition
+ ! CHECK: IF (( flag ? x : y )>5) THEN
+ if ((flag ? x : y) > 5) then
+ x = 1
+ end if
+ ! In print statement
+ ! CHECK: PRINT *, ( flag ? x : y )
+ print *, (flag ? x : y)
+ ! In assignment to array element
+ ! CHECK: arr(5) = ( flag ? x : y )
+ arr(5) = (flag ? x : y)
+end subroutine
+
+! Complex type conditionals
+! Complex literal constants such as (1.0, 2.0) contain no '?' and are still
+! parsed as literals inside the branches.
+subroutine complex_type(flag)
+ complex :: c1, c2, c3
+ complex(kind=8) :: c8a, c8b, c8c
+ logical :: flag
+ ! CHECK-LABEL: complex_type
+ ! CHECK: c3 = ( flag ? c1 : c2 )
+ c3 = (flag ? c1 : c2)
+ ! CHECK: c8c = ( flag ? c8a : c8b )
+ c8c = (flag ? c8a : c8b)
+ ! With complex literals
+ ! CHECK: c3 = ( flag ? (1.0,2.0) : (3.0,4.0) )
+ c3 = (flag ? (1.0, 2.0) : (3.0, 4.0))
+end subroutine
+
+! Array-valued conditionals (F2023 10.1.4)
+! Whole arrays, rank-2 arrays, and array sections as branch values.
+subroutine array_valued(flag)
+ integer :: arr1(5), arr2(5), arr3(5)
+ real :: mat1(3,3), mat2(3,3), mat3(3,3)
+ logical :: flag
+ ! CHECK-LABEL: array_valued
+ ! Whole array conditional
+ ! CHECK: arr3 = ( flag ? arr1 : arr2 )
+ ! TREE: ConditionalExpr
+ ! TREE-NEXT: Branch
+ ! TREE-NEXT: Scalar -> Logical -> Expr
+ ! TREE: Expr -> Designator -> DataRef -> Name = 'arr1'
+ ! TREE: Expr -> Designator -> DataRef -> Name = 'arr2'
+ arr3 = (flag ? arr1 : arr2)
+ ! Multidimensional array conditional
+ ! CHECK: mat3 = ( flag ? mat1 : mat2 )
+ mat3 = (flag ? mat1 : mat2)
+ ! Array section conditional
+ ! CHECK: arr3(1:3) = ( flag ? arr1(1:3) : arr2(1:3) )
+ arr3(1:3) = (flag ? arr1(1:3) : arr2(1:3))
+end subroutine
+
+! Derived type conditionals
+! Both values are of the same local derived type.
+subroutine derived_types(flag)
+ type :: point
+ real :: x, y
+ end type
+ type(point) :: p1, p2, p3
+ logical :: flag
+ ! CHECK-LABEL: derived_types
+ ! CHECK: p3 = ( flag ? p1 : p2 )
+ p3 = (flag ? p1 : p2)
+end subroutine
+
+! Character with different lengths
+subroutine character_lengths(flag)
+ character(len=5) :: short1, short2
+ character(len=10) :: medium1, medium2
+ character(len=20) :: long_result
+ logical :: flag
+ ! CHECK-LABEL: character_lengths
+ ! Same length characters
+ ! CHECK: short1 = ( flag ? "HELLO" : "WORLD" )
+ short1 = (flag ? "HELLO" : "WORLD")
+ ! Different length literals (lengths may differ; only type, kind, and rank
+ ! must match)
+ ! CHECK: long_result = ( flag ? "SHORT" : "MUCH LONGER STRING" )
+ long_result = (flag ? "SHORT" : "MUCH LONGER STRING")
+ ! Variables of different lengths (len=5 and len=10)
+ ! CHECK: medium1 = ( flag ? short1 : medium2 )
+ medium1 = (flag ? short1 : medium2)
+end subroutine
diff --git a/flang/test/Semantics/conditional-expr.f90 b/flang/test/Semantics/conditional-expr.f90
new file mode 100644
index 0000000000000..2245cc942381b
--- /dev/null
+++ b/flang/test/Semantics/conditional-expr.f90
@@ -0,0 +1,365 @@
+! RUN: %python %S/test_errors.py %s %flang_fc1
+! Test semantic analysis of conditional expressions (Fortran 2023)
+
+! Valid cases with basic types
+subroutine valid_basic_types(flag)
+ ! One conditional per intrinsic type; each pair of branches shares type and
+ ! kind, so no diagnostics are expected in this subroutine.
+ logical :: flag
+ integer :: i1, i2, i3
+ real :: r1, r2, r3
+ complex :: c1, c2, c3
+ logical :: l1, l2, l3
+ character(len=5) :: ch1, ch2, ch3
+
+ ! INTEGER conditionals
+ i3 = (flag ? i1 : i2)
+
+ ! REAL conditionals
+ r3 = (flag ? r1 : r2)
+
+ ! COMPLEX conditionals
+ c3 = (flag ? c1 : c2)
+
+ ! LOGICAL conditionals
+ l3 = (flag ? l1 : l2)
+
+ ! CHARACTER conditionals
+ ch3 = (flag ? ch1 : ch2)
+end subroutine
+
+! Valid cases with same kind
+subroutine valid_same_kind(flag)
+ ! Explicit kinds 4 and 8; each conditional pairs branches of a single kind,
+ ! so no diagnostics are expected.
+ logical :: flag
+ integer(kind=4) :: i4a, i4b, i4c
+ integer(kind=8) :: i8a, i8b, i8c
+ real(kind=4) :: r4a, r4b, r4c
+ real(kind=8) :: r8a, r8b, r8c
+
+ ! Same kind - valid
+ i4c = (flag ? i4a : i4b)
+ i8c = (flag ? i8a : i8b)
+ r4c = (flag ? r4a : r4b)
+ r8c = (flag ? r8a : r8b)
+end subroutine
+
+! Valid cases with literals
+subroutine valid_literals(flag)
+ ! Both branches are literal constants of matching type and kind.
+ logical :: flag
+ integer :: i
+ real :: r
+ character(len=10) :: ch
+
+ i = (flag ? 10 : 20)
+ r = (flag ? 1.0 : 2.0)
+ ch = (flag ? "HELLO" : "WORLD")
+end subroutine
+
+! Valid cases with nested conditionals
+subroutine valid_nested(flag1, flag2, x, y, z, w)
+ ! Conditionals nested as a branch value, used as a condition, and chained
+ ! in the else position; no diagnostics are expected.
+ logical :: flag1, flag2
+ integer :: x, y, z, w, result
+
+ ! Nested in value position
+ result = (flag1 ? (flag2 ? x : y) : z)
+
+ ! Nested in condition (condition is logical)
+ result = ((x > y ? flag1 : flag2) ? w : z)
+
+ ! Multi-branch
+ result = (x > 10 ? 100 : x > 5 ? 50 : 0)
+end subroutine
+
+! Valid cases with arrays
+subroutine valid_arrays(flag)
+ ! Array-valued branches of equal rank and shape; no diagnostics expected.
+ logical :: flag
+ integer :: arr1(10), arr2(10), arr3(10)
+ real :: mat1(3,3), mat2(3,3), mat3(3,3)
+
+ ! Whole array conditional
+ arr3 = (flag ? arr1 : arr2)
+
+ ! Multidimensional arrays
+ mat3 = (flag ? mat1 : mat2)
+
+ ! Array sections
+ arr3(1:5) = (flag ? arr1(1:5) : arr2(1:5))
+end subroutine
+
+! Valid cases with derived types
+subroutine valid_derived_types(flag)
+ ! Both branches have the same derived type; no diagnostics expected.
+ type :: point
+ real :: x, y
+ end type
+
+ logical :: flag
+ type(point) :: p1, p2, p3
+
+ p3 = (flag ? p1 : p2)
+end subroutine
+
+! Valid cases with character lengths
+subroutine valid_character_lengths(flag)
+ ! Character branches may differ in length; only type and kind must agree.
+ logical :: flag
+ character(len=5) :: short1, short2, short3
+ character(len=10) :: medium
+ character(len=20) :: long
+
+ ! Same length
+ short3 = (flag ? short1 : short2)
+
+ ! Different lengths - padding/truncation applies
+ medium = (flag ? short1 : medium)
+ long = (flag ? short1 : "A LONGER STRING")
+end subroutine
+
+! Valid: deferred-length character scalars
+subroutine valid_deferred_length_character(flag)
+ ! Deferred-length allocatable branches of different runtime lengths.
+ logical :: flag
+ character(len=:), allocatable :: str1, str2, result
+
+ str1 = "SHORT"
+ str2 = "A MUCH LONGER STRING"
+ ! Result length is determined by selected branch
+ result = (flag ? str1 : str2)
+end subroutine
+
+! Valid: assumed-length character arguments
+subroutine valid_assumed_length_character(flag, str1, str2)
+ ! Assumed-length dummy arguments as branch values.
+ logical :: flag
+ character(len=*) :: str1, str2
+ character(len=100) :: result
+
+ result = (flag ? str1 : str2)
+end subroutine
+
+! Error: condition must be logical
+subroutine error_non_logical_condition()
+ ! Each statement below must produce the diagnostic annotated directly above
+ ! it: the condition has a non-LOGICAL type.
+ integer :: i, x, y
+ real :: r
+ character :: ch
+
+ !ERROR: Condition in conditional expression must be LOGICAL; have INTEGER(4)
+ i = (i ? x : y)
+
+ !ERROR: Condition in conditional expression must be LOGICAL; have REAL(4)
+ i = (r ? x : y)
+
+ !ERROR: Condition in conditional expression must be LOGICAL; have CHARACTER(KIND=1,LEN=1_8)
+ i = (ch ? x : y)
+end subroutine
+
+! Error: type mismatch between branches
+subroutine error_type_mismatch(flag)
+ ! Branch values of different intrinsic types; each statement must produce
+ ! the diagnostic annotated directly above it.
+ logical :: flag
+ integer :: i1, i2
+ real :: r
+ character :: ch
+ complex :: c
+
+ !ERROR: All values in conditional expression must have the same type and kind; have INTEGER(4) and REAL(4)
+ i1 = (flag ? i2 : r)
+
+ !ERROR: All values in conditional expression must have the same type and kind; have INTEGER(4) and CHARACTER(KIND=1,LEN=1_8)
+ i1 = (flag ? i2 : ch)
+
+ !ERROR: All values in conditional expression must have the same type and kind; have REAL(4) and COMPLEX(4)
+ r = (flag ? r : c)
+
+ !ERROR: All values in conditional expression must have the same type and kind; have LOGICAL(4) and INTEGER(4)
+ flag = (flag ? flag : i1)
+end subroutine
+
+! Error: kind mismatch (F2023: C1004)
+subroutine error_kind_mismatch(flag)
+ ! Same type but different kind in the two branches; no implicit kind
+ ! conversion is applied, so each statement is diagnosed.
+ logical :: flag
+ integer(kind=4) :: i4
+ integer(kind=8) :: i8
+ real(kind=4) :: r4
+ real(kind=8) :: r8
+ complex(kind=4) :: c4
+ complex(kind=8) :: c8
+
+ !ERROR: All values in conditional expression must have the same type and kind; have INTEGER(4) and INTEGER(8)
+ i4 = (flag ? i4 : i8)
+
+ !ERROR: All values in conditional expression must have the same type and kind; have REAL(4) and REAL(8)
+ r4 = (flag ? r4 : r8)
+
+ !ERROR: All values in conditional expression must have the same type and kind; have COMPLEX(4) and COMPLEX(8)
+ c4 = (flag ? c4 : c8)
+end subroutine
+
+! Error: derived type mismatch
+subroutine error_derived_type_mismatch(flag)
+ ! type1 and type2 have identical components but are distinct types.
+ type :: type1
+ integer :: i
+ end type
+
+ type :: type2
+ integer :: i
+ end type
+
+ logical :: flag
+ type(type1) :: t1
+ type(type2) :: t2
+
+ !ERROR: All values in conditional expression must be the same derived type; have type1 and type2
+ t1 = (flag ? t1 : t2)
+end subroutine
+
+! Error: array rank mismatch
+subroutine error_array_rank_mismatch(flag)
+ ! Rank-1 and rank-2 branch values must be rejected.
+ logical :: flag
+ integer :: arr1(10), mat1(3,3), result(10)
+
+ !ERROR: All values in conditional expression must have the same rank; have rank 1 and 2
+ result = (flag ? arr1 : mat1)
+end subroutine
+
+! Error: scalar vs array mismatch
+subroutine error_scalar_array_mismatch(flag)
+ ! A scalar branch is not broadcast to match an array branch.
+ logical :: flag
+ integer :: scalar, arr(10), result(10)
+
+ !ERROR: All values in conditional expression must have the same rank; have rank 0 and 1
+ result = (flag ? scalar : arr)
+end subroutine
+
+! Error: condition must be scalar
+subroutine error_array_condition()
+ ! An array-valued condition is rejected even though its shape matches the
+ ! branch values.
+ logical :: flags(5)
+ integer :: x(5), y(5), result(5)
+
+ !ERROR: Condition in conditional expression must be scalar; have rank 1
+ result = (flags ? x : y)
+end subroutine
+
+! Valid cases with intrinsic functions
+subroutine valid_intrinsic_functions(x, y, flag)
+ ! Intrinsic function references as branch values; no diagnostics expected.
+ integer :: x, y
+ logical :: flag
+ integer :: result
+
+ result = (flag ? abs(x) : abs(y))
+ result = (flag ? max(x, y) : min(x, y))
+end subroutine
+
+! Valid: conditional in array constructor
+subroutine valid_in_array_constructor(flag, x, y)
+ ! Conditionals used as array-constructor elements.
+ logical :: flag
+ integer :: x, y, arr(3)
+
+ arr = [(flag ? x : y), (flag ? x + 1 : y + 1), (flag ? x + 2 : y + 2)]
+end subroutine
+
+! Valid: conditional in expression context
+subroutine valid_in_expression(flag, x, y)
+ ! Conditionals as operands of arithmetic and relational operations.
+ logical :: flag
+ integer :: x, y, z
+
+ z = (flag ? x : y) + 10
+ z = 2 * (flag ? x : y)
+
+ if ((flag ? x : y) > 5) then
+ z = 1
+ end if
+end subroutine
+
+! Note: allocatable/pointer differences are handled by assignment semantics;
+! the conditional expression itself only requires matching type, kind, and rank.
+
+! Valid: both branches allocatable
+subroutine valid_both_allocatable(flag)
+ ! Compile-time check only: alloc1 and alloc2 are never allocated here.
+ logical :: flag
+ integer, allocatable :: alloc1, alloc2, result
+
+ allocate(result)
+ result = (flag ? alloc1 : alloc2)
+end subroutine
+
+! Valid: both branches pointer
+subroutine valid_both_pointer(flag)
+ ! Compile-time check only: the pointers are never associated, and the
+ ! statement is an intrinsic assignment (=), not a pointer assignment (=>).
+ logical :: flag
+ integer, pointer :: ptr1, ptr2, result
+
+ result = (flag ? ptr1 : ptr2)
+end subroutine
+
+! Valid: elemental context
+elemental integer function conditional_elemental(flag, x, y)
+ ! A conditional expression is permitted inside an ELEMENTAL function.
+ logical, intent(in) :: flag
+ integer, intent(in) :: x, y
+
+ conditional_elemental = (flag ? x : y)
+end function
+
+! Valid: pure context
+pure integer function conditional_pure(flag, x, y)
+ ! A conditional expression is permitted inside a PURE function.
+ logical, intent(in) :: flag
+ integer, intent(in) :: x, y
+
+ conditional_pure = (flag ? x : y)
+end function
+
+! Valid: recursive context
+recursive integer function conditional_recursive(n, flag, x, y) result(res)
+ ! A conditional expression in the base case of a RECURSIVE function.
+ integer, intent(in) :: n
+ logical, intent(in) :: flag
+ integer, intent(in) :: x, y
+
+ if (n <= 0) then
+ res = (flag ? x : y)
+ else
+ res = conditional_recursive(n - 1, flag, x, y)
+ end if
+end function
+
+! Valid: nested multi-branch
+subroutine valid_multi_branch(x)
+ ! Four conditions chained in else position, plus a final else value.
+ integer :: x, result
+
+ ! Five-branch conditional
+ result = (x > 20 ? 1 : x > 15 ? 2 : x > 10 ? 3 : x > 5 ? 4 : 5)
+end subroutine
+
+! Error: polymorphic types not yet supported
+subroutine error_polymorphic(flag)
+ ! CLASS(*)-style polymorphic branches are currently rejected outright.
+ type :: base_t
+ integer :: i
+ end type
+
+ logical :: flag
+ class(base_t), allocatable :: poly1, poly2, result
+
+ !ERROR: Conditional expressions with polymorphic types (CLASS) are not yet supported
+ result = (flag ? poly1 : poly2)
+end subroutine
+
+! Error: mismatched character kinds
+subroutine error_character_kind_mismatch(flag)
+ ! Equal lengths but different character kinds must be rejected.
+ logical :: flag
+ character(kind=1, len=5) :: ch1
+ character(kind=4, len=5) :: ch4
+
+ !ERROR: All values in conditional expression must have the same type and kind; have CHARACTER(KIND=1,LEN=5_8) and CHARACTER(KIND=4,LEN=5_8)
+ ch1 = (flag ? ch1 : ch4)
+end subroutine
+
+! Valid: optional arguments
+subroutine valid_optional_args(flag, opt_x, opt_y)
+ ! OPTIONAL dummies as branch values, guarded by PRESENT checks.
+ logical :: flag
+ integer, optional :: opt_x, opt_y
+ integer :: result
+
+ if (present(opt_x) .and. present(opt_y)) then
+ result = (flag ? opt_x : opt_y)
+ end if
+end subroutine
+
+! Valid: mix of expressions and designators
+subroutine valid_mixed_expressions(flag, x, y)
+ ! Branch values that are arithmetic expressions rather than designators.
+ logical :: flag
+ integer :: x, y, result
+
+ result = (flag ? x + y : x - y)
+ result = (flag ? 2 * x : y / 2)
+end subroutine
>From 4c33be70639cb4935635e1785c1ee4c075c1e94d Mon Sep 17 00:00:00 2001
From: Caroline Newcombe <caroline.newcombe at hpe.com>
Date: Fri, 20 Mar 2026 14:47:32 -0500
Subject: [PATCH 2/2] [flang] Update conditional expressions to tree
representation
---
flang/include/flang/Evaluate/expression.h | 30 ++--
flang/include/flang/Evaluate/shape.h | 7 +-
flang/include/flang/Evaluate/tools.h | 24 ----
flang/include/flang/Evaluate/traverse.h | 2 +-
flang/include/flang/Semantics/dump-expr.h | 27 ++--
flang/lib/Evaluate/check-expression.cpp | 133 +-----------------
flang/lib/Evaluate/expression.cpp | 22 +--
flang/lib/Evaluate/fold-implementation.h | 17 +++
flang/lib/Evaluate/formatting.cpp | 17 ++-
flang/lib/Evaluate/shape.cpp | 63 +--------
flang/lib/Evaluate/tools.cpp | 121 ++--------------
flang/lib/Lower/IterationSpace.cpp | 11 +-
flang/lib/Lower/Support/Utils.cpp | 29 +---
flang/lib/Semantics/check-cuda.cpp | 30 ----
flang/lib/Semantics/check-data.cpp | 15 --
flang/lib/Semantics/check-do-forall.cpp | 11 --
flang/lib/Semantics/definable.cpp | 14 +-
flang/lib/Semantics/expression.cpp | 94 ++++---------
flang/lib/Semantics/openmp-utils.cpp | 19 ---
flang/lib/Semantics/resolve-names-utils.cpp | 14 --
flang/test/Evaluate/fold-conditional-expr.f90 | 30 ++++
flang/test/Semantics/conditional-expr.f90 | 8 +-
22 files changed, 160 insertions(+), 578 deletions(-)
create mode 100644 flang/test/Evaluate/fold-conditional-expr.f90
diff --git a/flang/include/flang/Evaluate/expression.h b/flang/include/flang/Evaluate/expression.h
index d46699cb7ac2c..68adcc5a698c9 100644
--- a/flang/include/flang/Evaluate/expression.h
+++ b/flang/include/flang/Evaluate/expression.h
@@ -397,27 +397,25 @@ template <typename T> class ConditionalExpr {
using Result = T;
CLASS_BOILERPLATE(ConditionalExpr)
ConditionalExpr(
- std::vector<Expr<SomeLogical>> &&conds, std::vector<Expr<Result>> &&vals)
- : conditions_{std::move(conds)}, values_{std::move(vals)} {
- CHECK(values_.size() == conditions_.size() + 1);
- }
+ Expr<SomeLogical> &&cond, Expr<Result> &&thenVal, Expr<Result> &&elseVal)
+ : condition_{std::move(cond)}, thenValue_{std::move(thenVal)},
+ elseValue_{std::move(elseVal)} {}
bool operator==(const ConditionalExpr &) const;
- const std::vector<Expr<SomeLogical>> &conditions() const {
- return conditions_;
- }
- std::vector<Expr<SomeLogical>> &conditions() { return conditions_; }
- const std::vector<Expr<Result>> &values() const { return values_; }
- std::vector<Expr<Result>> &values() { return values_; }
- int Rank() const { return values_.empty() ? 0 : values_.front().Rank(); }
- std::optional<DynamicType> GetType() const {
- return values_.empty() ? std::nullopt : values_.front().GetType();
- }
+ Expr<SomeLogical> &condition() { return condition_.value(); }
+ const Expr<SomeLogical> &condition() const { return condition_.value(); }
+ Expr<Result> &thenValue() { return thenValue_.value(); }
+ const Expr<Result> &thenValue() const { return thenValue_.value(); }
+ Expr<Result> &elseValue() { return elseValue_.value(); }
+ const Expr<Result> &elseValue() const { return elseValue_.value(); }
+ int Rank() const { return thenValue().Rank(); }
+ std::optional<DynamicType> GetType() const { return thenValue().GetType(); }
static constexpr int Corank() { return 0; }
llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
private:
- std::vector<Expr<SomeLogical>> conditions_; // size N
- std::vector<Expr<Result>> values_; // size N+1 (includes else)
+ common::CopyableIndirection<Expr<SomeLogical>> condition_;
+ common::CopyableIndirection<Expr<Result>> thenValue_;
+ common::CopyableIndirection<Expr<Result>> elseValue_;
};
// Array constructors
diff --git a/flang/include/flang/Evaluate/shape.h b/flang/include/flang/Evaluate/shape.h
index 3af78820f6c66..b409183471bb6 100644
--- a/flang/include/flang/Evaluate/shape.h
+++ b/flang/include/flang/Evaluate/shape.h
@@ -193,11 +193,8 @@ class GetShapeHelper
Result operator()(const ConditionalExpr<T> &conditional) const {
// Per F2023 10.1.4(7), the shape is determined by the selected branch,
// so return unknown extents for the rank.
- if (!conditional.values().empty()) {
- int rank{conditional.values().front().Rank()};
- return Shape(rank, std::nullopt);
- }
- return ScalarShape();
+ int rank{conditional.thenValue().Rank()};
+ return Shape(rank, std::nullopt);
}
template <typename D, typename R, typename LO, typename RO>
Result operator()(const Operation<D, R, LO, RO> &operation) const {
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 9ba29a5a2879c..09c942d8d21b6 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1074,16 +1074,6 @@ struct GetSymbolVectorHelper
Result operator()(const Component &) const;
Result operator()(const ArrayRef &) const;
Result operator()(const CoarrayRef &) const;
- template <typename T> Result operator()(const ConditionalExpr<T> &x) {
- Result result;
- for (const auto &cond : x.conditions()) {
- result = Combine(std::move(result), (*this)(cond));
- }
- for (const auto &val : x.values()) {
- result = Combine(std::move(result), (*this)(val));
- }
- return result;
- }
};
template <typename A> SymbolVector GetSymbolVector(const A &x) {
return GetSymbolVectorHelper{}(x);
@@ -1172,20 +1162,6 @@ class UnsafeToCopyVisitor : public AnyTraverse<UnsafeToCopyVisitor> {
return !admitPureCall_ || !procRef.proc().IsPure();
}
bool operator()(const CoarrayRef &) { return true; }
- template <typename T> bool operator()(const ConditionalExpr<T> &x) {
- // A conditional expression is unsafe to copy if any of its parts are unsafe
- for (const auto &condition : x.conditions()) {
- if ((*this)(condition)) {
- return true;
- }
- }
- for (const auto &value : x.values()) {
- if ((*this)(value)) {
- return true;
- }
- }
- return false;
- }
private:
bool admitPureCall_{false};
diff --git a/flang/include/flang/Evaluate/traverse.h b/flang/include/flang/Evaluate/traverse.h
index 306337274bf1f..44cfaa2a7073d 100644
--- a/flang/include/flang/Evaluate/traverse.h
+++ b/flang/include/flang/Evaluate/traverse.h
@@ -226,7 +226,7 @@ class Traverse {
}
// Conditional expressions (Fortran 2023)
template <typename T> Result operator()(const ConditionalExpr<T> &x) const {
- return Combine(x.conditions(), x.values());
+ // Combine the traversal results of the condition and both branch values.
+ return Combine(x.condition(), x.thenValue(), x.elseValue());
}
// Operations and wrappers
diff --git a/flang/include/flang/Semantics/dump-expr.h b/flang/include/flang/Semantics/dump-expr.h
index 5fbed77139958..d79a294258ff1 100644
--- a/flang/include/flang/Semantics/dump-expr.h
+++ b/flang/include/flang/Semantics/dump-expr.h
@@ -203,24 +203,15 @@ class DumpEvaluateExpr {
}
template <typename T> void Show(const evaluate::ConditionalExpr<T> &x) {
Indent("conditional expr "s + std::string(TypeOf<T>::name));
- const auto &conds = x.conditions();
- const auto &vals = x.values();
- // Show condition-value pairs
- for (const auto &[cond, val] : llvm::zip(conds, vals)) {
- Indent("branch");
- Indent("condition");
- Show(cond);
- Outdent();
- Indent("value");
- Show(val);
- Outdent();
- Outdent();
- }
- if (!vals.empty()) {
- Indent("default value");
- Show(vals.back());
- Outdent();
- }
+ // Dump the condition, then-value and else-value as labeled child nodes.
+ Indent("condition");
+ Show(x.condition());
+ Outdent();
+ Indent("then");
+ Show(x.thenValue());
+ Outdent();
+ Indent("else");
+ Show(x.elseValue());
+ Outdent();
Outdent();
}
void Show(const evaluate::Relational<evaluate::SomeType> &x);
diff --git a/flang/lib/Evaluate/check-expression.cpp b/flang/lib/Evaluate/check-expression.cpp
index 7c5e7c129765c..96ab323375b5a 100644
--- a/flang/lib/Evaluate/check-expression.cpp
+++ b/flang/lib/Evaluate/check-expression.cpp
@@ -92,23 +92,6 @@ class IsConstantExprHelper
!sym.attrs().test(semantics::Attr::VALUE)));
}
- template <typename T>
- bool operator()(const ConditionalExpr<T> &conditional) const {
- // A conditional expression is constant if all its conditions and values are
- // constant
- for (const auto &condition : conditional.conditions()) {
- if (!(*this)(condition)) {
- return false;
- }
- }
- for (const auto &value : conditional.values()) {
- if (!(*this)(value)) {
- return false;
- }
- }
- return true;
- }
-
bool operator()(const ImpliedDoIndex &ido) const {
return acImpliedDos_.find(ido.name) != acImpliedDos_.end() || !context_ ||
context_->GetImpliedDo(ido.name).has_value();
@@ -246,20 +229,6 @@ struct IsActuallyConstantHelper {
template <typename T> bool operator()(const Parentheses<T> &x) {
return (*this)(x.left());
}
- template <typename T> bool operator()(const ConditionalExpr<T> &x) {
- // A conditional expression is actually constant if all its parts are
- for (const auto &condition : x.conditions()) {
- if (!(*this)(condition)) {
- return false;
- }
- }
- for (const auto &value : x.values()) {
- if (!(*this)(value)) {
- return false;
- }
- }
- return true;
- }
template <typename T> bool operator()(const Expr<T> &x) {
return common::visit([=](const auto &y) { return (*this)(y); }, x.u);
}
@@ -527,21 +496,6 @@ class SuspiciousRealLiteralFinder
}
return (*this)(x.left());
}
- template <typename T> bool operator()(const ConditionalExpr<T> &x) const {
- // Check all conditions and values in the conditional expression for
- // suspicious literals
- for (const auto &cond : x.conditions()) {
- if ((*this)(cond)) {
- return true;
- }
- }
- for (const auto &value : x.values()) {
- if ((*this)(value)) {
- return true;
- }
- }
- return false;
- }
private:
int kind_;
@@ -581,16 +535,6 @@ class InexactLiteralConversionFlagClearer
mut.set_isFromInexactLiteralConversion(false);
return false;
}
- template <typename T> bool operator()(const ConditionalExpr<T> &x) const {
- // Clear flags in all conditions and values of the conditional expression
- for (const auto &cond : x.conditions()) {
- (*this)(cond);
- }
- for (const auto &value : x.values()) {
- (*this)(value);
- }
- return false;
- }
};
// Converts, folds, and then checks type, rank, and shape of an
@@ -858,20 +802,6 @@ class CheckSpecificationExprHelper
return std::nullopt;
}
- template <typename T> Result operator()(const ConditionalExpr<T> &x) const {
- for (const auto &cond : x.conditions()) {
- if (auto result{(*this)(cond)}) {
- return result;
- }
- }
- for (const auto &val : x.values()) {
- if (auto result{(*this)(val)}) {
- return result;
- }
- }
- return std::nullopt;
- }
-
Result operator()(const ProcedureRef &x) const {
if (const auto *symbol{x.proc().GetSymbol()}) {
const Symbol &ultimate{symbol->GetUltimate()};
@@ -1268,28 +1198,9 @@ class IsContiguousHelper
Result operator()(const NullPointer &) const { return true; }
template <typename T> Result operator()(const ConditionalExpr<T> &x) {
- // Track contiguity across all possible runtime branches
- bool hasContiguous{false};
- bool hasNonContiguous{false};
- bool hasUnknown{false};
- for (const auto &val : x.values()) {
- auto result{(*this)(val)};
- if (!result) {
- hasUnknown = true;
- } else if (*result) {
- hasContiguous = true;
- } else {
- hasNonContiguous = true;
- }
- }
- // Return definite result only if all values have uniform contiguity
- if (hasUnknown || (hasContiguous && hasNonContiguous)) {
- return std::nullopt;
- } else if (hasContiguous) {
- return true;
- } else {
- return false;
- }
+ // A conditional expression is treated as a value rather than a designator,
+ // so it is reported as contiguous without examining either branch.
+ // NOTE(review): revisit once conditional expressions are supported as
+ // actual arguments (not yet implemented per this series).
+ return true;
}
private:
@@ -1485,20 +1396,6 @@ struct IsErrorExprHelper : public AnyTraverse<IsErrorExprHelper, bool> {
bool operator()(const SpecificIntrinsic &x) {
return x.name == IntrinsicProcTable::InvalidName;
}
-
- template <typename T> bool operator()(const ConditionalExpr<T> &x) {
- for (const auto &cond : x.conditions()) {
- if ((*this)(cond)) {
- return true;
- }
- }
- for (const auto &val : x.values()) {
- if ((*this)(val)) {
- return true;
- }
- }
- return false;
- }
};
template <typename A> bool IsErrorExpr(const A &x) {
@@ -1614,20 +1511,6 @@ class StmtFunctionChecker
return std::nullopt;
}
- template <typename T> Result operator()(const ConditionalExpr<T> &x) {
- for (const auto &cond : x.conditions()) {
- if (auto result{(*this)(cond)}) {
- return result;
- }
- }
- for (const auto &val : x.values()) {
- if (auto result{(*this)(val)}) {
- return result;
- }
- }
- return std::nullopt;
- }
-
private:
const Symbol &sf_;
FoldingContext &context_;
@@ -1889,14 +1772,8 @@ class CollectUsedSymbolValuesHelper
template <typename T> Result operator()(const ConditionalExpr<T> &condExpr) {
auto restorer{common::ScopedSet(isDefinition_, false)};
- Result result;
- for (const auto &cond : condExpr.conditions()) {
- result = Combine(std::move(result), (*this)(cond));
- }
- for (const auto &val : condExpr.values()) {
- result = Combine(std::move(result), (*this)(val));
- }
- return result;
+ // All three operands are visited as uses, since isDefinition_ is false here.
+ return Combine((*this)(condExpr.condition()),
+ Combine((*this)(condExpr.thenValue()), (*this)(condExpr.elseValue())));
}
private:
diff --git a/flang/lib/Evaluate/expression.cpp b/flang/lib/Evaluate/expression.cpp
index e5289d14d63fc..aa5189b759334 100644
--- a/flang/lib/Evaluate/expression.cpp
+++ b/flang/lib/Evaluate/expression.cpp
@@ -64,23 +64,8 @@ Expr<Type<TypeCategory::Character, KIND>>::LEN() const {
}
return std::nullopt;
},
- [](const ConditionalExpr<Result> &c) -> T {
- // Return max of all branch lengths. If all have same constant
- // length, max folds to constant; otherwise signals deferred-length.
- std::optional<Expr<SubscriptInteger>> maxLen;
- for (const auto &value : c.values()) {
- if (auto len{value.LEN()}) {
- if (maxLen) {
- maxLen = Expr<SubscriptInteger>{Extremum<SubscriptInteger>{
- Ordering::Greater, std::move(*maxLen), std::move(*len)}};
- } else {
- maxLen = std::move(len);
- }
- } else {
- return std::nullopt;
- }
- }
- return maxLen;
+ [](const ConditionalExpr<Result> &) -> T {
+ return std::nullopt; // branch lengths may differ
},
[](const Designator<Result> &dr) { return dr.LEN(); },
[](const FunctionRef<Result> &fr) { return fr.LEN(); },
@@ -161,7 +146,8 @@ template <typename A> bool Extremum<A>::operator==(const Extremum &that) const {
template <typename A>
bool ConditionalExpr<A>::operator==(const ConditionalExpr &that) const {
- return conditions_ == that.conditions_ && values_ == that.values_;
+ // Equal when the condition and both branch values compare equal pairwise.
+ return condition_ == that.condition_ && thenValue_ == that.thenValue_ &&
+ elseValue_ == that.elseValue_;
}
template <int KIND>
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
index 529e3a9ad5a08..9c4ef22f923a9 100644
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -143,6 +143,8 @@ Expr<ImpliedDoIndex::Result> FoldOperation(
template <typename T>
Expr<T> FoldOperation(FoldingContext &, ArrayConstructor<T> &&);
Expr<SomeDerived> FoldOperation(FoldingContext &, StructureConstructor &&);
+template <typename T>
+Expr<T> FoldOperation(FoldingContext &, ConditionalExpr<T> &&);
template <typename T>
std::optional<Constant<T>> Folder<T>::GetNamedConstant(const Symbol &symbol0) {
@@ -2211,6 +2213,21 @@ Expr<T> FoldOperation(FoldingContext &context, RealToIntPower<T> &&x) {
x.right().u);
}
+template <typename T>
+Expr<T> FoldOperation(FoldingContext &context, ConditionalExpr<T> &&x) {
+ // Fold the condition first; it decides which branch is live.
+ x.condition() = Fold(context, std::move(x.condition()));
+ // Only the selected branch of a conditional expression is evaluated. When
+ // the condition folds to a scalar constant, fold and return just that
+ // branch: folding the other one could report diagnostics (such as integer
+ // division by zero or an out-of-range subscript) for code that never runs.
+ auto folded{Fold(
+ context, ConvertToType<LogicalResult>(Expr<SomeLogical>{x.condition()}))};
+ if (auto cst{GetScalarConstantValue<LogicalResult>(folded)}) {
+ return Fold(
+ context, std::move(cst->IsTrue() ? x.thenValue() : x.elseValue()));
+ }
+ // The condition is not constant, so both branches remain; fold each one.
+ x.thenValue() = Fold(context, std::move(x.thenValue()));
+ x.elseValue() = Fold(context, std::move(x.elseValue()));
+ return Expr<T>{std::move(x)};
+}
+
template <typename T>
Expr<T> FoldOperation(FoldingContext &context, Extremum<T> &&x) {
if (auto array{ApplyElementwise(context, x,
diff --git a/flang/lib/Evaluate/formatting.cpp b/flang/lib/Evaluate/formatting.cpp
index 4c1002cf1cfc5..e4905c3e3c894 100644
--- a/flang/lib/Evaluate/formatting.cpp
+++ b/flang/lib/Evaluate/formatting.cpp
@@ -590,14 +590,21 @@ llvm::raw_ostream &ArrayConstructor<SomeDerived>::AsFortran(
template <typename T>
llvm::raw_ostream &ConditionalExpr<T>::AsFortran(llvm::raw_ostream &o) const {
o << '(';
- for (std::size_t i = 0; i < conditions_.size(); ++i) {
- conditions_[i].AsFortran(o);
+ // Emit "(c1 ? v1 : c2 ? v2 : ... : vN)". An else value that is itself a
+ // ConditionalExpr is printed inline as the next link of one multi-branch
+ // chain, which is equivalent to the nested form.
+ const ConditionalExpr<T> *node{this};
+ while (true) {
+ node->condition().AsFortran(o);
o << " ? ";
- values_[i].AsFortran(o);
+ node->thenValue().AsFortran(o);
o << " : ";
+ if (const auto *nested{
+ std::get_if<ConditionalExpr<T>>(&node->elseValue().u)}) {
+ node = nested; // continue the chain with the nested conditional
+ } else {
+ node->elseValue().AsFortran(o); // terminal else value
+ break;
+ }
}
- // Last value is the else clause
- values_.back().AsFortran(o);
return o << ')';
}
diff --git a/flang/lib/Evaluate/shape.cpp b/flang/lib/Evaluate/shape.cpp
index e37213041a7e4..27913c3559c71 100644
--- a/flang/lib/Evaluate/shape.cpp
+++ b/flang/lib/Evaluate/shape.cpp
@@ -221,63 +221,14 @@ ConstantSubscript GetSize(const ConstantSubscripts &shape) {
return size;
}
-// Helper visitor for ContainsAnyImpliedDoIndex
-struct ImpliedDoIndexVisitor : public AnyTraverse<ImpliedDoIndexVisitor> {
- using Base = AnyTraverse<ImpliedDoIndexVisitor>;
- ImpliedDoIndexVisitor() : Base{*this} {}
- using Base::operator();
- bool operator()(const ImpliedDoIndex &) { return true; }
-
- // Template helper for ConditionalExpr handlers
- template <typename T> bool CheckConditionalExpr(const ConditionalExpr<T> &x) {
- for (const auto &cond : x.conditions()) {
- if ((*this)(cond)) {
- return true;
- }
- }
- for (const auto &val : x.values()) {
- if ((*this)(val)) {
- return true;
- }
- }
- return false;
- }
-
- // ConditionalExpr handlers - check all conditions and values for implied DO
- // indices
- template <int KIND>
- bool operator()(const ConditionalExpr<Type<TypeCategory::Integer, KIND>> &x) {
- return CheckConditionalExpr(x);
- }
- template <int KIND>
- bool operator()(const ConditionalExpr<Type<TypeCategory::Logical, KIND>> &x) {
- return CheckConditionalExpr(x);
- }
- template <int KIND>
- bool operator()(const ConditionalExpr<Type<TypeCategory::Real, KIND>> &x) {
- return CheckConditionalExpr(x);
- }
- template <int KIND>
- bool operator()(const ConditionalExpr<Type<TypeCategory::Complex, KIND>> &x) {
- return CheckConditionalExpr(x);
- }
- template <int KIND>
- bool operator()(
- const ConditionalExpr<Type<TypeCategory::Unsigned, KIND>> &x) {
- return CheckConditionalExpr(x);
- }
- template <int KIND>
- bool operator()(
- const ConditionalExpr<Type<TypeCategory::Character, KIND>> &x) {
- return CheckConditionalExpr(x);
- }
- bool operator()(const ConditionalExpr<SomeKind<TypeCategory::Derived>> &x) {
- return CheckConditionalExpr(x);
- }
-};
-
bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
- return ImpliedDoIndexVisitor{}(expr);
+ // The generic Traverse handler already descends into ConditionalExpr
+ // operands, so only ImpliedDoIndex needs a custom handler here.
+ struct MyVisitor : public AnyTraverse<MyVisitor> {
+ using Base = AnyTraverse<MyVisitor>;
+ MyVisitor() : Base{*this} {}
+ using Base::operator();
+ bool operator()(const ImpliedDoIndex &) { return true; }
+ };
+ return MyVisitor{}(expr);
}
// Determines lower bound on a dimension. This can be other than 1 only
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 4570d55d7e73d..ac785eaa241d5 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1085,18 +1085,6 @@ struct CollectSymbolsHelper
semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
return {symbol};
}
- template <typename T>
- semantics::UnorderedSymbolSet operator()(const ConditionalExpr<T> &x) {
- // Collect symbols from all conditions and values
- semantics::UnorderedSymbolSet result;
- for (const auto &cond : x.conditions()) {
- result.merge((*this)(cond));
- }
- for (const auto &val : x.values()) {
- result.merge((*this)(val));
- }
- return result;
- }
};
template <typename A> semantics::UnorderedSymbolSet CollectSymbols(const A &x) {
return CollectSymbolsHelper{}(x);
@@ -1130,18 +1118,6 @@ struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper,
semantics::UnorderedSymbolSet operator()(const ProcedureRef &) const {
return {};
}
- template <typename T>
- semantics::UnorderedSymbolSet operator()(const ConditionalExpr<T> &x) {
- // Collect CUDA symbols from all conditions and values
- semantics::UnorderedSymbolSet result;
- for (const auto &cond : x.conditions()) {
- result.merge((*this)(cond));
- }
- for (const auto &val : x.values()) {
- result.merge((*this)(val));
- }
- return result;
- }
};
template <typename A>
semantics::UnorderedSymbolSet CollectCudaSymbols(const A &x) {
@@ -1209,19 +1185,8 @@ struct HasVectorSubscriptHelper
bool operator()(const ProcedureRef &) const {
return false; // don't descend into function call arguments
}
- template <typename T> bool operator()(const ConditionalExpr<T> &x) {
- // Check if any condition or value has a vector subscript
- for (const auto &cond : x.conditions()) {
- if ((*this)(cond)) {
- return true;
- }
- }
- for (const auto &val : x.values()) {
- if ((*this)(val)) {
- return true;
- }
- }
- return false;
+ template <typename T> bool operator()(const ConditionalExpr<T> &) const {
+ // Neither the condition nor the branch values are examined.
+ return false; // not a variable designator
}
};
@@ -1249,20 +1214,6 @@ struct HasConstantHelper : public AnyTraverse<HasConstantHelper, bool,
}
// Only look for constant not in subscript.
bool operator()(const Subscript &) const { return false; }
- template <typename T> bool operator()(const ConditionalExpr<T> &x) {
- // Check if any condition or value has a constant
- for (const auto &cond : x.conditions()) {
- if ((*this)(cond)) {
- return true;
- }
- }
- for (const auto &val : x.values()) {
- if ((*this)(val)) {
- return true;
- }
- }
- return false;
- }
};
bool HasConstant(const Expr<SomeType> &expr) {
@@ -1277,21 +1228,6 @@ struct HasStructureComponentHelper
using Base::operator();
bool operator()(const Component &) const { return true; }
-
- template <typename T> bool operator()(const ConditionalExpr<T> &x) {
- // Check if any condition or value has a structure component
- for (const auto &cond : x.conditions()) {
- if ((*this)(cond)) {
- return true;
- }
- }
- for (const auto &val : x.values()) {
- if ((*this)(val)) {
- return true;
- }
- }
- return false;
- }
};
bool HasStructureComponent(const Expr<SomeType> &expr) {
@@ -1358,21 +1294,6 @@ class FindImpureCallHelper
return call.proc().GetName();
}
- template <typename T> Result operator()(const ConditionalExpr<T> &x) const {
- // Check if any condition or value contains an impure call
- for (const auto &cond : x.conditions()) {
- if (auto result{(*this)(cond)}) {
- return result;
- }
- }
- for (const auto &val : x.values()) {
- if (auto result{(*this)(val)}) {
- return result;
- }
- }
- return std::nullopt;
- }
-
private:
FoldingContext &context_;
};
@@ -1811,11 +1732,17 @@ struct ArgumentExtractor
template <typename T> Result operator()(const ConditionalExpr<T> &x) const {
// ConditionalExpr is a top-level operation; collect its immediate operands
Arguments args;
- for (const auto &cond : x.conditions()) {
- args.push_back(AsSomeExpr(cond));
- }
- for (const auto &val : x.values()) {
- args.push_back(AsSomeExpr(val));
+ const ConditionalExpr<T> *node{&x};
+ while (true) {
+ args.push_back(AsSomeExpr(node->condition()));
+ args.push_back(AsSomeExpr(node->thenValue()));
+ if (const auto *nested =
+ std::get_if<ConditionalExpr<T>>(&node->elseValue().u)) {
+ node = nested;
+ } else {
+ args.push_back(AsSomeExpr(node->elseValue()));
+ break;
+ }
}
return {Operator::Conditional, std::move(args)};
}
@@ -1989,11 +1916,7 @@ struct ConvertCollector
template <typename T> Result operator()(const ConditionalExpr<T> &x) const {
// For conditional expressions, collect conversions from all values only
- Result result;
- for (const auto &val : x.values()) {
- result = Combine(std::move(result), (*this)(val));
- }
- return result;
+ return Combine((*this)(x.thenValue()), (*this)(x.elseValue()));
}
template <typename... Rs>
@@ -2088,22 +2011,6 @@ struct VariableFinder : public evaluate::AnyTraverse<VariableFinder> {
return evaluate::AsGenericExpr(common::Clone(x)) == var;
}
- template <typename T>
- bool operator()(const evaluate::ConditionalExpr<T> &x) const {
- // Check if any condition or value contains the variable
- for (const auto &cond : x.conditions()) {
- if ((*this)(cond)) {
- return true;
- }
- }
- for (const auto &val : x.values()) {
- if ((*this)(val)) {
- return true;
- }
- }
- return false;
- }
-
private:
const SomeExpr &var;
};
diff --git a/flang/lib/Lower/IterationSpace.cpp b/flang/lib/Lower/IterationSpace.cpp
index 1f650a9fa5412..52a15223bc1e6 100644
--- a/flang/lib/Lower/IterationSpace.cpp
+++ b/flang/lib/Lower/IterationSpace.cpp
@@ -214,13 +214,10 @@ class ArrayBaseFinder {
}
template <typename T>
RT find(const Fortran::evaluate::ConditionalExpr<T> &x) {
- // Find array bases in all conditions and values
- for (const auto &cond : x.conditions()) {
- (void)find(cond);
- }
- for (const auto &val : x.values()) {
- (void)find(val);
- }
+ // Find array bases in the condition and in both the then- and else-values
+ (void)find(x.condition());
+ (void)find(x.thenValue());
+ (void)find(x.elseValue());
return {};
}
RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
diff --git a/flang/lib/Lower/Support/Utils.cpp b/flang/lib/Lower/Support/Utils.cpp
index 230587f3b951c..ee5af11f1dbc2 100644
--- a/flang/lib/Lower/Support/Utils.cpp
+++ b/flang/lib/Lower/Support/Utils.cpp
@@ -160,15 +160,8 @@ class HashEvaluateExpr {
}
template <typename T>
static unsigned getHashValue(const Fortran::evaluate::ConditionalExpr<T> &x) {
- unsigned conds = 1u;
- for (const auto &cond : x.conditions()) {
- conds -= getHashValue(cond);
- }
- unsigned vals = 3u;
- for (const auto &val : x.values()) {
- vals += getHashValue(val);
- }
- return conds * 151u - vals;
+ return getHashValue(x.condition()) * 151u -
+ getHashValue(x.thenValue()) * 3u + getHashValue(x.elseValue());
}
template <Fortran::common::TypeCategory TC, int KIND>
static unsigned getHashValue(
@@ -432,21 +425,9 @@ class IsEqualEvaluateExpr {
static bool isEqual(const Fortran::evaluate::ConditionalExpr<T> &x,
const Fortran::evaluate::ConditionalExpr<T> &y) {
// Compare all conditions and values
- if (x.conditions().size() != y.conditions().size() ||
- x.values().size() != y.values().size()) {
- return false;
- }
- for (size_t i = 0; i < x.conditions().size(); ++i) {
- if (!isEqual(x.conditions()[i], y.conditions()[i])) {
- return false;
- }
- }
- for (size_t i = 0; i < x.values().size(); ++i) {
- if (!isEqual(x.values()[i], y.values()[i])) {
- return false;
- }
- }
- return true;
+ return isEqual(x.condition(), y.condition()) &&
+ isEqual(x.thenValue(), y.thenValue()) &&
+ isEqual(x.elseValue(), y.elseValue());
}
template <typename A>
static bool isEqual(const Fortran::evaluate::RealToIntPower<A> &x,
diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index b69845fbb6be2..13c523da13c25 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -111,21 +111,6 @@ struct DeviceExprChecker
return parser::MessageFormattedText(
"'%s' may not be called in device code"_err_en_US, x.GetName());
}
- template <typename T>
- Result operator()(const evaluate::ConditionalExpr<T> &x) const {
- // Check all conditions and values for device code violations
- for (const auto &cond : x.conditions()) {
- if (Result msg{(*this)(cond)}) {
- return msg;
- }
- }
- for (const auto &val : x.values()) {
- if (Result msg{(*this)(val)}) {
- return msg;
- }
- }
- return Result{};
- }
SemanticsContext &context_;
};
@@ -165,21 +150,6 @@ struct FindHostArray
}
return nullptr;
}
- template <typename T>
- Result operator()(const evaluate::ConditionalExpr<T> &x) const {
- // Check all conditions and values for host arrays
- for (const auto &cond : x.conditions()) {
- if (Result hostArray{(*this)(cond)}) {
- return hostArray;
- }
- }
- for (const auto &val : x.values()) {
- if (Result hostArray{(*this)(val)}) {
- return hostArray;
- }
- }
- return nullptr;
- }
};
template <typename A>
diff --git a/flang/lib/Semantics/check-data.cpp b/flang/lib/Semantics/check-data.cpp
index c93711c8fc313..9dbbc163d85b3 100644
--- a/flang/lib/Semantics/check-data.cpp
+++ b/flang/lib/Semantics/check-data.cpp
@@ -174,21 +174,6 @@ class DataVarChecker : public evaluate::AllTraverse<DataVarChecker, true> {
}
}
- template <typename T> bool operator()(const evaluate::ConditionalExpr<T> &x) {
- // Check all conditions and values
- for (const auto &cond : x.conditions()) {
- if (!(*this)(cond)) {
- return false;
- }
- }
- for (const auto &val : x.values()) {
- if (!(*this)(val)) {
- return false;
- }
- }
- return true;
- }
-
private:
bool CheckSubscriptExpr(
const std::optional<evaluate::IndirectSubscriptIntegerExpr> &x) const {
diff --git a/flang/lib/Semantics/check-do-forall.cpp b/flang/lib/Semantics/check-do-forall.cpp
index beb0f777ccf32..bf92d920f282e 100644
--- a/flang/lib/Semantics/check-do-forall.cpp
+++ b/flang/lib/Semantics/check-do-forall.cpp
@@ -1143,17 +1143,6 @@ struct CollectActualArgumentsHelper
return Combine(ActualArgumentSet{arg},
CollectActualArgumentsHelper{}(arg.UnwrapExpr()));
}
- template <typename T>
- ActualArgumentSet operator()(const evaluate::ConditionalExpr<T> &x) const {
- ActualArgumentSet result;
- for (const auto &cond : x.conditions()) {
- result = Combine(std::move(result), (*this)(cond));
- }
- for (const auto &val : x.values()) {
- result = Combine(std::move(result), (*this)(val));
- }
- return result;
- }
};
template <typename A> ActualArgumentSet CollectActualArguments(const A &x) {
diff --git a/flang/lib/Semantics/definable.cpp b/flang/lib/Semantics/definable.cpp
index 581c1796b692e..6f5eb0cb41ccd 100644
--- a/flang/lib/Semantics/definable.cpp
+++ b/flang/lib/Semantics/definable.cpp
@@ -305,18 +305,8 @@ class DuplicatedSubscriptFinder
}
return anyVector ? false : (*this)(aRef.base());
}
- template <typename T> bool operator()(const evaluate::ConditionalExpr<T> &x) {
- // Check all conditions and values for duplicated subscripts
- for (const auto &cond : x.conditions()) {
- if ((*this)(cond)) {
- return true;
- }
- }
- for (const auto &val : x.values()) {
- if ((*this)(val)) {
- return true;
- }
- }
+ template <typename T> bool operator()(const evaluate::ConditionalExpr<T> &) {
+ // A conditional expression is not a variable and cannot be definable.
return false;
}
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index a99fd6b0a94b3..374de9343badb 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -3896,24 +3896,10 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::ConditionalExpr &x) {
for (const auto &branch : branches) {
const auto &condition{std::get<parser::ScalarLogicalExpr>(branch.t)};
const auto &value{std::get<common::Indirection<parser::Expr>>(branch.t)};
- MaybeExpr condExpr{Analyze(condition.thing.thing.value())};
+ MaybeExpr condExpr{Analyze(condition)};
if (!condExpr) {
return std::nullopt;
}
- if (!std::get_if<Expr<SomeLogical>>(&condExpr->u)) {
- if (const auto type{condExpr->GetType()}) {
- Say("Condition in conditional expression must be LOGICAL; have %s"_err_en_US,
- type->AsFortran());
- } else {
- Say("Condition in conditional expression must be LOGICAL"_err_en_US);
- }
- return std::nullopt;
- }
- if (condExpr->Rank() != 0) {
- Say("Condition in conditional expression must be scalar; have rank %d"_err_en_US,
- condExpr->Rank());
- return std::nullopt;
- }
conditions.push_back(std::move(condExpr));
MaybeExpr valExpr{Analyze(value.value())};
if (!valExpr) {
@@ -3940,7 +3926,8 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::ConditionalExpr &x) {
"values must have exactly one more element than conditions");
// F2023 C1004: Each expr shall have the same declared type, kind type
- // parameters, and rank Reject typeless expressions (BOZ and NULL)
+ // parameters, and rank.
+ // Reject typeless expressions (BOZ and NULL)
for (const auto &value : values) {
// BOZ arrays are auto-converted in array constructors, but bare BOZ are not
// allowed
@@ -4014,35 +4001,25 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::ConditionalExpr &x) {
}
}
- // Dispatch on the runtime type of values[0] to build the appropriately
- // typed ConditionalExpr, with nested visitation to unwrap category->specific
- // types.
+ // Build a right-skewed ConditionalExpr tree from the else value.
+ // Dispatch on values.back() to recover the concrete type; std::get
+ // is safe because all types were validated above.
return common::visit(
common::visitors{
[&](const BOZLiteralConstant &) -> MaybeExpr {
DIE("BOZ literal should have been eliminated by type validation");
},
- [&](Expr<SomeDerived> &&derivedExpr) -> MaybeExpr {
- std::vector<Expr<SomeLogical>> typedConditions;
- typedConditions.reserve(conditions.size());
- for (auto &cond : conditions) {
- auto *logicalExpr{std::get_if<Expr<SomeLogical>>(&cond->u)};
- CHECK(logicalExpr && "Condition should be SomeLogical");
- typedConditions.emplace_back(std::move(*logicalExpr));
- }
- std::vector<Expr<SomeDerived>> typedValues;
- typedValues.reserve(values.size());
- // Use the moved-in first value directly, then process remaining
- // values
- typedValues.emplace_back(std::move(derivedExpr));
- for (auto &val : llvm::drop_begin(values, 1)) {
- auto *derivedVal{std::get_if<Expr<SomeDerived>>(&val->u)};
- CHECK(derivedVal && "Value should be SomeDerived");
- typedValues.emplace_back(std::move(*derivedVal));
+ [&](Expr<SomeDerived> &&elseExpr) -> MaybeExpr {
+ Expr<SomeDerived> result{std::move(elseExpr)};
+ for (int i = static_cast<int>(conditions.size()) - 1; i >= 0; --i) {
+ Expr<SomeLogical> cond{
+ std::move(std::get<Expr<SomeLogical>>(conditions[i]->u))};
+ Expr<SomeDerived> thenVal{
+ std::move(std::get<Expr<SomeDerived>>(values[i]->u))};
+ result = Expr<SomeDerived>{evaluate::ConditionalExpr<SomeDerived>{
+ std::move(cond), std::move(thenVal), std::move(result)}};
}
- return AsGenericExpr(
- Expr<SomeDerived>{evaluate::ConditionalExpr<SomeDerived>{
- std::move(typedConditions), std::move(typedValues)}});
+ return AsGenericExpr(std::move(result));
},
[&](auto &&categoryExpr) -> MaybeExpr {
using CategoryType = std::decay_t<decltype(categoryExpr)>;
@@ -4057,37 +4034,26 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::ConditionalExpr &x) {
} else {
return common::visit(
[&](auto &&specificExpr) -> MaybeExpr {
- using SpecificType = std::decay_t<decltype(specificExpr)>;
- using T = typename SpecificType::Result;
- std::vector<Expr<SomeLogical>> typedConditions;
- typedConditions.reserve(conditions.size());
- for (auto &cond : conditions) {
- auto *logicalExpr{
- std::get_if<Expr<SomeLogical>>(&cond->u)};
- CHECK(logicalExpr && "Condition should be SomeLogical");
- typedConditions.emplace_back(std::move(*logicalExpr));
- }
- std::vector<Expr<T>> typedValues;
- typedValues.reserve(values.size());
- // Use the moved-in first value directly, then process
- // remaining values
- typedValues.emplace_back(std::move(specificExpr));
- for (auto &val : llvm::drop_begin(values, 1)) {
- auto *catExpr{std::get_if<CategoryType>(&val->u)};
- CHECK(catExpr && "Value should be CategoryType");
- auto *specificVal{std::get_if<Expr<T>>(&catExpr->u)};
- CHECK(specificVal && "Value should be Expr<T>");
- typedValues.emplace_back(std::move(*specificVal));
+ using T =
+ typename std::decay_t<decltype(specificExpr)>::Result;
+ Expr<T> result{std::move(specificExpr)};
+ for (int i = static_cast<int>(conditions.size()) - 1;
+ i >= 0; --i) {
+ Expr<SomeLogical> cond{std::move(
+ std::get<Expr<SomeLogical>>(conditions[i]->u))};
+ Expr<T> thenVal{std::move(std::get<Expr<T>>(
+ std::get<CategoryType>(values[i]->u).u))};
+ result =
+ Expr<T>{evaluate::ConditionalExpr<T>{std::move(cond),
+ std::move(thenVal), std::move(result)}};
}
- return AsGenericExpr(CategoryType{Expr<T>{
- evaluate::ConditionalExpr<T>{std::move(typedConditions),
- std::move(typedValues)}}});
+ return AsGenericExpr(CategoryType{std::move(result)});
},
categoryExpr.u);
}
},
},
- std::move(values[0]->u));
+ std::move(values.back()->u));
}
MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::DefinedUnary &x) {
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index 326f79b86a694..f3f034530af9f 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -396,25 +396,6 @@ struct DesignatorCollector : public evaluate::Traverse<DesignatorCollector,
(moveAppend(v, std::move(results)), ...);
return v;
}
-
- template <typename T>
- Result operator()(const evaluate::ConditionalExpr<T> &x) const {
- // Collect designators from all conditions and values
- Result result;
- for (const auto &cond : x.conditions()) {
- Result condResult = (*this)(cond);
- for (auto &s : condResult) {
- result.push_back(std::move(s));
- }
- }
- for (const auto &val : x.values()) {
- Result valResult = (*this)(val);
- for (auto &s : valResult) {
- result.push_back(std::move(s));
- }
- }
- return result;
- }
};
std::vector<SomeExpr> GetAllDesignators(const SomeExpr &expr) {
diff --git a/flang/lib/Semantics/resolve-names-utils.cpp b/flang/lib/Semantics/resolve-names-utils.cpp
index e7ed72b2bfa34..ef34c89182f7f 100644
--- a/flang/lib/Semantics/resolve-names-utils.cpp
+++ b/flang/lib/Semantics/resolve-names-utils.cpp
@@ -694,20 +694,6 @@ class SymbolMapper : public evaluate::AnyTraverse<SymbolMapper, bool> {
}
return false;
}
- template <typename T> bool operator()(const evaluate::ConditionalExpr<T> &x) {
- // Map symbols in all conditions and values
- for (const auto &cond : x.conditions()) {
- if ((*this)(cond)) {
- return true;
- }
- }
- for (const auto &val : x.values()) {
- if ((*this)(val)) {
- return true;
- }
- }
- return false;
- }
void MapSymbolExprs(Symbol &);
Symbol *CopySymbol(const Symbol *);
diff --git a/flang/test/Evaluate/fold-conditional-expr.f90 b/flang/test/Evaluate/fold-conditional-expr.f90
new file mode 100644
index 0000000000000..0337791ad1c5b
--- /dev/null
+++ b/flang/test/Evaluate/fold-conditional-expr.f90
@@ -0,0 +1,30 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of conditional expressions (Fortran 2023)
+module m
+ ! Basic scalar folding: a constant condition selects the corresponding branch value.
+ logical, parameter :: test_true_int = (.true. ? 1 : 2) == 1
+ logical, parameter :: test_false_int = (.false. ? 1 : 2) == 2
+ logical, parameter :: test_true_real = (.true. ? 1.0 : 2.0) == 1.0
+ logical, parameter :: test_false_real = (.false. ? 1.0 : 2.0) == 2.0
+ logical, parameter :: test_true_logical = (.true. ? .true. : .false.)
+ logical, parameter :: test_false_logical = (.false. ? .false. : .true.)
+
+ ! Multi-branch: right-skewed tree folds correctly.
+ ! (.true. ? 10 : .false. ? 20 : 30) == 10
+ logical, parameter :: test_multi_first = (.true. ? 10 : .false. ? 20 : 30) == 10
+ ! (.false. ? 10 : .true. ? 20 : 30) == 20
+ logical, parameter :: test_multi_second = (.false. ? 10 : .true. ? 20 : 30) == 20
+ ! (.false. ? 10 : .false. ? 20 : 30) == 30
+ logical, parameter :: test_multi_third = (.false. ? 10 : .false. ? 20 : 30) == 30
+
+ ! Named constant expressions in branches are folded.
+ integer, parameter :: x = 5
+ logical, parameter :: test_branch_fold = (.true. ? x + 1 : x + 2) == 6
+
+ ! Named constant as condition.
+ logical, parameter :: cond = .true.
+ logical, parameter :: test_named_cond = (cond ? 42 : 0) == 42
+
+ ! Character: constant condition selects the branch value.
+ logical, parameter :: test_char = (.true. ? 'yes' : 'no') == 'yes'
+end module
diff --git a/flang/test/Semantics/conditional-expr.f90 b/flang/test/Semantics/conditional-expr.f90
index 2245cc942381b..200224c5a2ca0 100644
--- a/flang/test/Semantics/conditional-expr.f90
+++ b/flang/test/Semantics/conditional-expr.f90
@@ -137,13 +137,13 @@ subroutine error_non_logical_condition()
real :: r
character :: ch
- !ERROR: Condition in conditional expression must be LOGICAL; have INTEGER(4)
+ !ERROR: Must have LOGICAL type, but is INTEGER(4)
i = (i ? x : y)
- !ERROR: Condition in conditional expression must be LOGICAL; have REAL(4)
+ !ERROR: Must have LOGICAL type, but is REAL(4)
i = (r ? x : y)
- !ERROR: Condition in conditional expression must be LOGICAL; have CHARACTER(KIND=1,LEN=1_8)
+ !ERROR: Must have LOGICAL type, but is CHARACTER(KIND=1,LEN=1_8)
i = (ch ? x : y)
end subroutine
@@ -229,7 +229,7 @@ subroutine error_array_condition()
logical :: flags(5)
integer :: x(5), y(5), result(5)
- !ERROR: Condition in conditional expression must be scalar; have rank 1
+ !ERROR: Must be a scalar value, but is a rank-1 array
result = (flags ? x : y)
end subroutine
More information about the flang-commits
mailing list