[flang-commits] [flang] [flang][OpenMP] Add optional SemanticsContext parameter to loop utilities (PR #191231)
via flang-commits
flang-commits at lists.llvm.org
Thu Apr 9 09:08:13 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Krzysztof Parzyszek (kparzysz)
<details>
<summary>Changes</summary>
…ties
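Some of these utilities may be used during symbol resolution, which runs before expression analysis. In that situation, the `typedExpr` values normally stored in `parser::Expr` may not be available yet. To obtain the numeric values of such expressions, it may be necessary to invoke the expression analyzer directly, and that requires a `SemanticsContext`. This change therefore adds an optional `SemanticsContext *` parameter (defaulting to `nullptr`) to these utilities. When a context is provided and no stored value can be obtained from the expression, the new `GetIntValueFromExpr` helper falls back to analyzing the expression with `evaluate::ExpressionAnalyzer`.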
---
Full diff: https://github.com/llvm/llvm-project/pull/191231.diff
2 Files Affected:
- (modified) flang/include/flang/Semantics/openmp-utils.h (+28-11)
- (modified) flang/lib/Semantics/openmp-utils.cpp (+46-28)
``````````diff
diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index b741c8eac3248..fad6f0db34f3a 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -118,6 +118,17 @@ std::optional<evaluate::DynamicType> GetDynamicType(
const parser::Expr &parserExpr);
std::optional<bool> GetLogicalValue(const SomeExpr &expr);
+std::optional<int64_t> GetIntValueFromExpr(
+ const parser::Expr &parserExpr, SemanticsContext *semaCtx = nullptr);
+
+template <typename T>
+std::optional<int64_t> GetIntValueFromExpr(
+ const T &wrappedExpr, SemanticsContext *semaCtx = nullptr) {
+ if (auto *parserExpr{parser::Unwrap<parser::Expr>(wrappedExpr)}) {
+ return GetIntValueFromExpr(*parserExpr, semaCtx);
+ }
+ return std::nullopt;
+}
std::optional<bool> IsContiguous(
SemanticsContext &semaCtx, const parser::OmpObject &object);
@@ -194,25 +205,29 @@ template <typename T> struct WithReason {
WithReason<int64_t> GetArgumentValueWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
- unsigned version);
+ unsigned version, SemanticsContext *semaCtx = nullptr);
WithReason<int64_t> GetNumArgumentsWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
- unsigned version);
+ unsigned version, SemanticsContext *semaCtx = nullptr);
WithReason<int64_t> GetHeightWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version);
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx = nullptr);
// Return the depth of the affected nests:
// {affected-depth, reason, must-be-perfect-nest}.
std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version);
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx = nullptr);
// Return the range of the affected nests in the sequence:
// {first, count, reason}.
// If the range is "the whole sequence", the return value will be {1, -1, ...}.
WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version);
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx = nullptr);
/// Return the depth in which all loops must be rectangular.
WithReason<int64_t> GetRectangularNestDepthWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version);
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx = nullptr);
// Count the required loop count from range. If count == -1, return -1,
// indicating all loops in the sequence.
@@ -223,11 +238,12 @@ std::optional<int64_t> GetRequiredCount(
struct LoopSequence {
LoopSequence(const parser::ExecutionPartConstruct &root, unsigned version,
- bool allowAllLoops = false);
+ bool allowAllLoops = false, SemanticsContext *semaCtx = nullptr);
template <typename R, typename = std::enable_if_t<is_range_v<R>>>
- LoopSequence(const R &range, unsigned version, bool allowAllLoops = false)
- : version_(version), allowAllLoops_(allowAllLoops) {
+ LoopSequence(const R &range, unsigned version, bool allowAllLoops = false,
+ SemanticsContext *semaCtx = nullptr)
+ : version_(version), allowAllLoops_(allowAllLoops), semaCtx_(semaCtx) {
entry_ = std::make_unique<Construct>(range, nullptr);
createChildrenFromRange(entry_->location);
precalculate();
@@ -265,8 +281,8 @@ struct LoopSequence {
private:
using Construct = ExecutionPartIterator::Construct;
- LoopSequence(
- std::unique_ptr<Construct> entry, unsigned version, bool allowAllLoops);
+ LoopSequence(std::unique_ptr<Construct> entry, unsigned version,
+ bool allowAllLoops, SemanticsContext *semaCtx = nullptr);
template <typename R, typename = std::enable_if_t<is_range_v<R>>>
void createChildrenFromRange(const R &range) {
@@ -318,6 +334,7 @@ struct LoopSequence {
bool allowAllLoops_;
std::unique_ptr<Construct> entry_;
std::vector<LoopSequence> children_;
+ SemanticsContext *semaCtx_{nullptr};
};
} // namespace omp
} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index 13004fb4bab6b..a36cadc04ee65 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -342,6 +342,19 @@ std::optional<bool> GetLogicalValue(const SomeExpr &expr) {
return LogicalConstantVistor{}(expr);
}
+std::optional<int64_t> GetIntValueFromExpr(
+ const parser::Expr &parserExpr, SemanticsContext *semaCtx) {
+ if (auto value{GetIntValue(parserExpr)}) {
+ return value;
+ }
+ if (semaCtx) {
+ if (auto expr{evaluate::ExpressionAnalyzer{*semaCtx}.Analyze(parserExpr)}) {
+ return evaluate::ToInt64(expr);
+ }
+ }
+ return std::nullopt;
+}
+
namespace {
struct ContiguousHelper {
ContiguousHelper(SemanticsContext &context)
@@ -691,10 +704,10 @@ static SymbolVector SelectUsedSymbols(
WithReason<int64_t> GetArgumentValueWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
- unsigned version) {
+ unsigned version, SemanticsContext *semaCtx) {
if (auto *clause{parser::omp::FindClause(spec, clauseId)}) {
if (auto *expr{parser::Unwrap<parser::Expr>(clause->u)}) {
- if (auto value{GetIntValue(*expr)}) {
+ if (auto value{GetIntValueFromExpr(*expr, semaCtx)}) {
std::string name{GetUpperName(clauseId, version)};
Reason reason;
reason.Say(clause->source,
@@ -723,7 +736,7 @@ static WithReason<int64_t> GetNumArgumentsWithReasonForType(
WithReason<int64_t> GetNumArgumentsWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
- unsigned version) {
+ unsigned version, SemanticsContext *semaCtx) {
if (auto *clause{parser::omp::FindClause(spec, clauseId)}) {
std::string name{GetUpperName(clauseId, version)};
// Try the types used for list items.
@@ -744,7 +757,8 @@ WithReason<int64_t> GetNumArgumentsWithReason(
}
WithReason<int64_t> GetHeightWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version) {
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx) {
bool isFullUnroll{IsFullUnroll(spec)};
if (!isFullUnroll && !IsTransformableLoop(spec)) {
@@ -756,7 +770,7 @@ WithReason<int64_t> GetHeightWithReason(
switch (spec.DirName().v) {
case llvm::omp::Directive::OMPD_flatten:
if (auto &&value{GetArgumentValueWithReason(
- spec, llvm::omp::Clause::OMPC_depth, version)}) {
+ spec, llvm::omp::Clause::OMPC_depth, version, semaCtx)}) {
// FLATTEN DEPTH(n) replaces n loops with 1.
return {int64_t(1) - *value.value, std::move(value.reason)};
} else {
@@ -775,7 +789,7 @@ WithReason<int64_t> GetHeightWithReason(
case llvm::omp::Directive::OMPD_stripe:
case llvm::omp::Directive::OMPD_tile:
return GetNumArgumentsWithReason(
- spec, llvm::omp::Clause::OMPC_sizes, version);
+ spec, llvm::omp::Clause::OMPC_sizes, version, semaCtx);
case llvm::omp::Directive::OMPD_unroll:
if (isFullUnroll) {
Reason reason;
@@ -955,7 +969,8 @@ WithReason<T> operator+(T a, const WithReason<T> &b) {
// Return the depth of the affected nests:
// {affected-depth, must-be-perfect-nest}.
std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version) {
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx) {
llvm::omp::Directive dir{spec.DirId()};
bool allowsCollapse{llvm::omp::isAllowedClauseForDirective(
dir, llvm::omp::Clause::OMPC_collapse, version)};
@@ -964,9 +979,9 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
if (allowsCollapse || allowsOrdered) {
auto [ccount, creason]{GetArgumentValueWithReason(
- spec, llvm::omp::Clause::OMPC_collapse, version)};
+ spec, llvm::omp::Clause::OMPC_collapse, version, semaCtx)};
auto [ocount, oreason]{GetArgumentValueWithReason(
- spec, llvm::omp::Clause::OMPC_ordered, version)};
+ spec, llvm::omp::Clause::OMPC_ordered, version, semaCtx)};
// Ignore invalid arguments.
if (ccount <= 0) {
ccount = std::nullopt;
@@ -989,7 +1004,7 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
// Get the length of the argument list to PERMUTATION.
if (parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_permutation)) {
auto [num, reason]{GetNumArgumentsWithReason(
- spec, llvm::omp::Clause::OMPC_permutation, version)};
+ spec, llvm::omp::Clause::OMPC_permutation, version, semaCtx)};
return {{num, std::move(reason)}, true};
}
// PERMUTATION not specified, assume PERMUTATION(2, 1).
@@ -1004,14 +1019,14 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
case llvm::omp::Directive::OMPD_tile: {
// Get the length of the argument list to SIZES.
auto [num, reason]{GetNumArgumentsWithReason(
- spec, llvm::omp::Clause::OMPC_sizes, version)};
+ spec, llvm::omp::Clause::OMPC_sizes, version, semaCtx)};
return {{num, std::move(reason)}, true};
}
case llvm::omp::Directive::OMPD_fuse: {
// Get the value from the argument to DEPTH.
if (parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_depth)) {
auto [count, reason]{GetArgumentValueWithReason(
- spec, llvm::omp::Clause::OMPC_depth, version)};
+ spec, llvm::omp::Clause::OMPC_depth, version, semaCtx)};
return {{count, std::move(reason)}, true};
}
std::string name{GetUpperName(llvm::omp::Clause::OMPC_depth, version)};
@@ -1035,7 +1050,8 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
// Return the range of the affected nests in the sequence:
// {first, count, std::move(reason)}.
WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version) {
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx) {
llvm::omp::Directive dir{spec.DirId()};
if (dir == llvm::omp::Directive::OMPD_fuse) {
@@ -1043,8 +1059,8 @@ WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
if (auto *clause{
parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_looprange)}) {
auto &range{DEREF(parser::Unwrap<parser::OmpLooprangeClause>(clause->u))};
- std::optional<int64_t> first{GetIntValue(std::get<0>(range.t))};
- std::optional<int64_t> count{GetIntValue(std::get<1>(range.t))};
+ auto first{GetIntValueFromExpr(std::get<0>(range.t), semaCtx)};
+ auto count{GetIntValueFromExpr(std::get<1>(range.t), semaCtx)};
if (!first || !count || *first <= 0 || *count <= 0) {
return {};
}
@@ -1071,8 +1087,9 @@ WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
}
WithReason<int64_t> GetRectangularNestDepthWithReason(
- const parser::OmpDirectiveSpecification &spec, unsigned version) {
- auto [depth, _]{GetAffectedNestDepthWithReason(spec, version)};
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx) {
+ auto [depth, _]{GetAffectedNestDepthWithReason(spec, version, semaCtx)};
if (!depth) {
return {};
}
@@ -1195,8 +1212,8 @@ static_assert(HasSourceT<parser::ExecutionPartConstruct>::value);
#endif // EXPENSIVE_CHECKS
LoopSequence::LoopSequence(const parser::ExecutionPartConstruct &root,
- unsigned version, bool allowAllLoops)
- : version_(version), allowAllLoops_(allowAllLoops) {
+ unsigned version, bool allowAllLoops, SemanticsContext *semaCtx)
+ : version_(version), allowAllLoops_(allowAllLoops), semaCtx_(semaCtx) {
entry_ = createConstructEntry(root);
assert(entry_ && "Expecting loop like code");
@@ -1204,10 +1221,10 @@ LoopSequence::LoopSequence(const parser::ExecutionPartConstruct &root,
precalculate();
}
-LoopSequence::LoopSequence(
- std::unique_ptr<Construct> entry, unsigned version, bool allowAllLoops)
+LoopSequence::LoopSequence(std::unique_ptr<Construct> entry, unsigned version,
+ bool allowAllLoops, SemanticsContext *semaCtx)
: version_(version), allowAllLoops_(allowAllLoops),
- entry_(std::move(entry)) {
+ entry_(std::move(entry)), semaCtx_(semaCtx) {
createChildrenFromRange(entry_->location);
precalculate();
}
@@ -1240,7 +1257,7 @@ void LoopSequence::createChildrenFromRange(
for (auto &code : BlockRange(begin, end, BlockRange::Step::Over)) {
if (auto entry{createConstructEntry(code)}) {
children_.push_back(
- LoopSequence(std::move(entry), version_, allowAllLoops_));
+ LoopSequence(std::move(entry), version_, allowAllLoops_, semaCtx_));
// Even when DO WHILE et al are allowed to have entries, still treat
// them as invalid intervening code.
// Give it priority over other kinds of invalid interveninig code.
@@ -1346,7 +1363,7 @@ WithReason<int64_t> LoopSequence::calculateLength() const {
}
auto *loopRange{parser::Unwrap<parser::OmpLooprangeClause>(*clause)};
- std::optional<int64_t> count{GetIntValue(std::get<1>(loopRange->t))};
+ auto count{GetIntValueFromExpr(std::get<1>(loopRange->t), semaCtx_)};
if (!count || *count <= 0) {
return {};
}
@@ -1456,11 +1473,12 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
if (auto *clause{parser::omp::FindClause(
beginSpec, llvm::omp::Clause::OMPC_depth)}) {
auto &expr{parser::UnwrapRef<parser::Expr>(clause->u)};
- auto value{GetIntValue(expr)};
+ auto value{GetIntValueFromExpr(expr, semaCtx_)};
// The result is a perfect nest only if all loop in the sequence
// are fused.
if (value && nestedLength.value) {
- auto range{GetAffectedLoopRangeWithReason(beginSpec, version_)};
+ auto range{
+ GetAffectedLoopRangeWithReason(beginSpec, version_, semaCtx_)};
if (auto required{GetRequiredCount(range.value)}) {
if (*required == -1 || *required == *nestedLength.value) {
return Depth{value, value};
@@ -1508,7 +1526,7 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
beginSpec, llvm::omp::Clause::OMPC_partial)}) {
std::optional<int64_t> factor;
if (auto *expr{parser::Unwrap<parser::Expr>(clause->u)}) {
- factor = GetIntValue(*expr);
+ factor = GetIntValueFromExpr(*expr, semaCtx_);
}
// If it's a partial unroll, and the unroll count is 1, then this
// construct is a no-op.
@@ -1555,7 +1573,7 @@ WithReason<int64_t> LoopSequence::calculateHeight() const {
if (auto *omp{parser::Unwrap<parser::OpenMPLoopConstruct>(*entry_->owner)}) {
const parser::OmpDirectiveSpecification &beginSpec{omp->BeginDir()};
if (IsLoopTransforming(beginSpec.DirId())) {
- return GetHeightWithReason(beginSpec, version_);
+ return GetHeightWithReason(beginSpec, version_, semaCtx_);
}
return {0, Reason()};
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/191231
More information about the flang-commits mailing list