[flang-commits] [flang] [flang][OpenMP] Add optional SemanticsContext parameter to loop utilities (PR #191231)
Krzysztof Parzyszek via flang-commits
flang-commits at lists.llvm.org
Thu Apr 9 09:07:29 PDT 2026
https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/191231
Some of the utilities may be used during symbol resolution, which happens before expression analysis is done. In such situations, the typedExpr values normally stored in parser::Expr may not be available. To obtain the numeric values of expressions, it may be necessary to use the expression analyzer directly, which requires a SemanticsContext to be provided.
>From 091bd3fbdf480e212f47061b899f36f540f59975 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 8 Apr 2026 14:02:17 -0500
Subject: [PATCH] [flang][OpenMP] Add optional SemanticsContext parameter to
loop utilities
Some of the utilities may be used during symbol resolution, which happens
before expression analysis is done. In such situations, the typedExpr
values normally stored in parser::Expr may not be available.
To obtain the numeric values of expressions, it may be necessary to use the
expression analyzer directly, which requires a SemanticsContext to be provided.
---
flang/include/flang/Semantics/openmp-utils.h | 39 ++++++++---
flang/lib/Semantics/openmp-utils.cpp | 74 ++++++++++++--------
2 files changed, 74 insertions(+), 39 deletions(-)
diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index b741c8eac3248..fad6f0db34f3a 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -118,6 +118,17 @@ std::optional<evaluate::DynamicType> GetDynamicType(
const parser::Expr &parserExpr);
std::optional<bool> GetLogicalValue(const SomeExpr &expr);
+std::optional<int64_t> GetIntValueFromExpr( // Integer value of parserExpr, or std::nullopt if it cannot be determined.
+ const parser::Expr &parserExpr, SemanticsContext *semaCtx = nullptr); // A non-null semaCtx allows falling back to the expression analyzer.
+// Convenience overload for parser nodes that wrap a parser::Expr: unwrap it and forward to the overload above.
+template <typename T>
+std::optional<int64_t> GetIntValueFromExpr( // A plain const parser::Expr& argument still selects the non-template overload.
+ const T &wrappedExpr, SemanticsContext *semaCtx = nullptr) {
+ if (auto *parserExpr{parser::Unwrap<parser::Expr>(wrappedExpr)}) {
+ return GetIntValueFromExpr(*parserExpr, semaCtx);
+ }
+ return std::nullopt; // wrappedExpr does not contain a parser::Expr.
+}
std::optional<bool> IsContiguous(
SemanticsContext &semaCtx, const parser::OmpObject &object);
@@ -194,25 +205,29 @@ template <typename T> struct WithReason {
WithReason<int64_t> GetArgumentValueWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
- unsigned version);
+ unsigned version, SemanticsContext *semaCtx = nullptr);
WithReason<int64_t> GetNumArgumentsWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
- unsigned version);
+ unsigned version, SemanticsContext *semaCtx = nullptr);
WithReason<int64_t> GetHeightWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version);
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx = nullptr);
// Return the depth of the affected nests:
// {affected-depth, reason, must-be-perfect-nest}.
std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version);
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx = nullptr);
// Return the range of the affected nests in the sequence:
// {first, count, reason}.
// If the range is "the whole sequence", the return value will be {1, -1, ...}.
WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version);
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx = nullptr);
/// Return the depth in which all loops must be rectangular.
WithReason<int64_t> GetRectangularNestDepthWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version);
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx = nullptr);
// Count the required loop count from range. If count == -1, return -1,
// indicating all loops in the sequence.
@@ -223,11 +238,12 @@ std::optional<int64_t> GetRequiredCount(
struct LoopSequence {
LoopSequence(const parser::ExecutionPartConstruct &root, unsigned version,
- bool allowAllLoops = false);
+ bool allowAllLoops = false, SemanticsContext *semaCtx = nullptr);
template <typename R, typename = std::enable_if_t<is_range_v<R>>>
- LoopSequence(const R &range, unsigned version, bool allowAllLoops = false)
- : version_(version), allowAllLoops_(allowAllLoops) {
+ LoopSequence(const R &range, unsigned version, bool allowAllLoops = false,
+ SemanticsContext *semaCtx = nullptr)
+ : version_(version), allowAllLoops_(allowAllLoops), semaCtx_(semaCtx) {
entry_ = std::make_unique<Construct>(range, nullptr);
createChildrenFromRange(entry_->location);
precalculate();
@@ -265,8 +281,8 @@ struct LoopSequence {
private:
using Construct = ExecutionPartIterator::Construct;
- LoopSequence(
- std::unique_ptr<Construct> entry, unsigned version, bool allowAllLoops);
+ LoopSequence(std::unique_ptr<Construct> entry, unsigned version,
+ bool allowAllLoops, SemanticsContext *semaCtx = nullptr);
template <typename R, typename = std::enable_if_t<is_range_v<R>>>
void createChildrenFromRange(const R &range) {
@@ -318,6 +334,7 @@ struct LoopSequence {
bool allowAllLoops_;
std::unique_ptr<Construct> entry_;
std::vector<LoopSequence> children_;
+ SemanticsContext *semaCtx_{nullptr};
};
} // namespace omp
} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index 13004fb4bab6b..a36cadc04ee65 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -342,6 +342,19 @@ std::optional<bool> GetLogicalValue(const SomeExpr &expr) {
return LogicalConstantVistor{}(expr);
}
+std::optional<int64_t> GetIntValueFromExpr( // Integer value of parserExpr, or std::nullopt if it cannot be determined.
+ const parser::Expr &parserExpr, SemanticsContext *semaCtx) {
+ if (auto value{GetIntValue(parserExpr)}) { // First try GetIntValue, which needs no SemanticsContext.
+ return value;
+ }
+ if (semaCtx) { // Fallback when the stored typedExpr is unavailable, e.g. when called during symbol resolution.
+ if (auto expr{evaluate::ExpressionAnalyzer{*semaCtx}.Analyze(parserExpr)}) { // NOTE(review): confirm any diagnostics Analyze emits here are acceptable before name resolution completes.
+ return evaluate::ToInt64(expr);
+ }
+ }
+ return std::nullopt;
+}
+
namespace {
struct ContiguousHelper {
ContiguousHelper(SemanticsContext &context)
@@ -691,10 +704,10 @@ static SymbolVector SelectUsedSymbols(
WithReason<int64_t> GetArgumentValueWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
- unsigned version) {
+ unsigned version, SemanticsContext *semaCtx) {
if (auto *clause{parser::omp::FindClause(spec, clauseId)}) {
if (auto *expr{parser::Unwrap<parser::Expr>(clause->u)}) {
- if (auto value{GetIntValue(*expr)}) {
+ if (auto value{GetIntValueFromExpr(*expr, semaCtx)}) {
std::string name{GetUpperName(clauseId, version)};
Reason reason;
reason.Say(clause->source,
@@ -723,7 +736,7 @@ static WithReason<int64_t> GetNumArgumentsWithReasonForType(
WithReason<int64_t> GetNumArgumentsWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
- unsigned version) {
+ unsigned version, SemanticsContext *semaCtx) {
if (auto *clause{parser::omp::FindClause(spec, clauseId)}) {
std::string name{GetUpperName(clauseId, version)};
// Try the types used for list items.
@@ -744,7 +757,8 @@ WithReason<int64_t> GetNumArgumentsWithReason(
}
WithReason<int64_t> GetHeightWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version) {
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx) {
bool isFullUnroll{IsFullUnroll(spec)};
if (!isFullUnroll && !IsTransformableLoop(spec)) {
@@ -756,7 +770,7 @@ WithReason<int64_t> GetHeightWithReason(
switch (spec.DirName().v) {
case llvm::omp::Directive::OMPD_flatten:
if (auto &&value{GetArgumentValueWithReason(
- spec, llvm::omp::Clause::OMPC_depth, version)}) {
+ spec, llvm::omp::Clause::OMPC_depth, version, semaCtx)}) {
// FLATTEN DEPTH(n) replaces n loops with 1.
return {int64_t(1) - *value.value, std::move(value.reason)};
} else {
@@ -775,7 +789,7 @@ WithReason<int64_t> GetHeightWithReason(
case llvm::omp::Directive::OMPD_stripe:
case llvm::omp::Directive::OMPD_tile:
return GetNumArgumentsWithReason(
- spec, llvm::omp::Clause::OMPC_sizes, version);
+ spec, llvm::omp::Clause::OMPC_sizes, version, semaCtx);
case llvm::omp::Directive::OMPD_unroll:
if (isFullUnroll) {
Reason reason;
@@ -955,7 +969,8 @@ WithReason<T> operator+(T a, const WithReason<T> &b) {
// Return the depth of the affected nests:
// {affected-depth, must-be-perfect-nest}.
std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version) {
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx) {
llvm::omp::Directive dir{spec.DirId()};
bool allowsCollapse{llvm::omp::isAllowedClauseForDirective(
dir, llvm::omp::Clause::OMPC_collapse, version)};
@@ -964,9 +979,9 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
if (allowsCollapse || allowsOrdered) {
auto [ccount, creason]{GetArgumentValueWithReason(
- spec, llvm::omp::Clause::OMPC_collapse, version)};
+ spec, llvm::omp::Clause::OMPC_collapse, version, semaCtx)};
auto [ocount, oreason]{GetArgumentValueWithReason(
- spec, llvm::omp::Clause::OMPC_ordered, version)};
+ spec, llvm::omp::Clause::OMPC_ordered, version, semaCtx)};
// Ignore invalid arguments.
if (ccount <= 0) {
ccount = std::nullopt;
@@ -989,7 +1004,7 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
// Get the length of the argument list to PERMUTATION.
if (parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_permutation)) {
auto [num, reason]{GetNumArgumentsWithReason(
- spec, llvm::omp::Clause::OMPC_permutation, version)};
+ spec, llvm::omp::Clause::OMPC_permutation, version, semaCtx)};
return {{num, std::move(reason)}, true};
}
// PERMUTATION not specified, assume PERMUTATION(2, 1).
@@ -1004,14 +1019,14 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
case llvm::omp::Directive::OMPD_tile: {
// Get the length of the argument list to SIZES.
auto [num, reason]{GetNumArgumentsWithReason(
- spec, llvm::omp::Clause::OMPC_sizes, version)};
+ spec, llvm::omp::Clause::OMPC_sizes, version, semaCtx)};
return {{num, std::move(reason)}, true};
}
case llvm::omp::Directive::OMPD_fuse: {
// Get the value from the argument to DEPTH.
if (parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_depth)) {
auto [count, reason]{GetArgumentValueWithReason(
- spec, llvm::omp::Clause::OMPC_depth, version)};
+ spec, llvm::omp::Clause::OMPC_depth, version, semaCtx)};
return {{count, std::move(reason)}, true};
}
std::string name{GetUpperName(llvm::omp::Clause::OMPC_depth, version)};
@@ -1035,7 +1050,8 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
// Return the range of the affected nests in the sequence:
// {first, count, std::move(reason)}.
WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version) {
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx) {
llvm::omp::Directive dir{spec.DirId()};
if (dir == llvm::omp::Directive::OMPD_fuse) {
@@ -1043,8 +1059,8 @@ WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
if (auto *clause{
parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_looprange)}) {
auto &range{DEREF(parser::Unwrap<parser::OmpLooprangeClause>(clause->u))};
- std::optional<int64_t> first{GetIntValue(std::get<0>(range.t))};
- std::optional<int64_t> count{GetIntValue(std::get<1>(range.t))};
+ auto first{GetIntValueFromExpr(std::get<0>(range.t), semaCtx)};
+ auto count{GetIntValueFromExpr(std::get<1>(range.t), semaCtx)};
if (!first || !count || *first <= 0 || *count <= 0) {
return {};
}
@@ -1071,8 +1087,9 @@ WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
}
WithReason<int64_t> GetRectangularNestDepthWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version) {
- auto [depth, _]{GetAffectedNestDepthWithReason(spec, version)};
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx) {
+ auto [depth, _]{GetAffectedNestDepthWithReason(spec, version, semaCtx)};
if (!depth) {
return {};
}
@@ -1195,8 +1212,8 @@ static_assert(HasSourceT<parser::ExecutionPartConstruct>::value);
#endif // EXPENSIVE_CHECKS
LoopSequence::LoopSequence(const parser::ExecutionPartConstruct &root,
- unsigned version, bool allowAllLoops)
- : version_(version), allowAllLoops_(allowAllLoops) {
+ unsigned version, bool allowAllLoops, SemanticsContext *semaCtx)
+ : version_(version), allowAllLoops_(allowAllLoops), semaCtx_(semaCtx) {
entry_ = createConstructEntry(root);
assert(entry_ && "Expecting loop like code");
@@ -1204,10 +1221,10 @@ LoopSequence::LoopSequence(const parser::ExecutionPartConstruct &root,
precalculate();
}
-LoopSequence::LoopSequence(
- std::unique_ptr<Construct> entry, unsigned version, bool allowAllLoops)
+LoopSequence::LoopSequence(std::unique_ptr<Construct> entry, unsigned version,
+ bool allowAllLoops, SemanticsContext *semaCtx)
: version_(version), allowAllLoops_(allowAllLoops),
- entry_(std::move(entry)) {
+ entry_(std::move(entry)), semaCtx_(semaCtx) {
createChildrenFromRange(entry_->location);
precalculate();
}
@@ -1240,7 +1257,7 @@ void LoopSequence::createChildrenFromRange(
for (auto &code : BlockRange(begin, end, BlockRange::Step::Over)) {
if (auto entry{createConstructEntry(code)}) {
children_.push_back(
- LoopSequence(std::move(entry), version_, allowAllLoops_));
+ LoopSequence(std::move(entry), version_, allowAllLoops_, semaCtx_));
// Even when DO WHILE et al are allowed to have entries, still treat
// them as invalid intervening code.
// Give it priority over other kinds of invalid interveninig code.
@@ -1346,7 +1363,7 @@ WithReason<int64_t> LoopSequence::calculateLength() const {
}
auto *loopRange{parser::Unwrap<parser::OmpLooprangeClause>(*clause)};
- std::optional<int64_t> count{GetIntValue(std::get<1>(loopRange->t))};
+ auto count{GetIntValueFromExpr(std::get<1>(loopRange->t), semaCtx_)};
if (!count || *count <= 0) {
return {};
}
@@ -1456,11 +1473,12 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
if (auto *clause{parser::omp::FindClause(
beginSpec, llvm::omp::Clause::OMPC_depth)}) {
auto &expr{parser::UnwrapRef<parser::Expr>(clause->u)};
- auto value{GetIntValue(expr)};
+ auto value{GetIntValueFromExpr(expr, semaCtx_)};
// The result is a perfect nest only if all loop in the sequence
// are fused.
if (value && nestedLength.value) {
- auto range{GetAffectedLoopRangeWithReason(beginSpec, version_)};
+ auto range{
+ GetAffectedLoopRangeWithReason(beginSpec, version_, semaCtx_)};
if (auto required{GetRequiredCount(range.value)}) {
if (*required == -1 || *required == *nestedLength.value) {
return Depth{value, value};
@@ -1508,7 +1526,7 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
beginSpec, llvm::omp::Clause::OMPC_partial)}) {
std::optional<int64_t> factor;
if (auto *expr{parser::Unwrap<parser::Expr>(clause->u)}) {
- factor = GetIntValue(*expr);
+ factor = GetIntValueFromExpr(*expr, semaCtx_);
}
// If it's a partial unroll, and the unroll count is 1, then this
// construct is a no-op.
@@ -1555,7 +1573,7 @@ WithReason<int64_t> LoopSequence::calculateHeight() const {
if (auto *omp{parser::Unwrap<parser::OpenMPLoopConstruct>(*entry_->owner)}) {
const parser::OmpDirectiveSpecification &beginSpec{omp->BeginDir()};
if (IsLoopTransforming(beginSpec.DirId())) {
- return GetHeightWithReason(beginSpec, version_);
+ return GetHeightWithReason(beginSpec, version_, semaCtx_);
}
return {0, Reason()};
}
More information about the flang-commits
mailing list