[llvm-branch-commits] [flang] [flang][OpenMP] Identify affected loops, provide reason (PR #185299)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Mar 8 08:43:33 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Krzysztof Parzyszek (kparzysz)
<details>
<summary>Changes</summary>
Implement utility functions to calculate the number of affected loops in a sequence or in a nest. Provide a reason for the returned value to be used in an explanatory message.
---
Full diff: https://github.com/llvm/llvm-project/pull/185299.diff
3 Files Affected:
- (modified) flang/include/flang/Semantics/openmp-utils.h (+54-5)
- (modified) flang/lib/Semantics/check-omp-loop.cpp (+13-24)
- (modified) flang/lib/Semantics/openmp-utils.cpp (+181-8)
``````````diff
diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index 1f1262738a341..a9801cf5b5753 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -22,11 +22,14 @@
#include "flang/Semantics/tools.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Format.h"
#include <optional>
#include <string>
+#include <tuple>
#include <type_traits>
#include <utility>
+#include <vector>
namespace Fortran::semantics {
class Scope;
@@ -109,6 +112,34 @@ bool IsPointerAssignment(const evaluate::Assignment &x);
MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp);
+// A "because" message attached to a computed value. `text` holds the fully
+// formatted message (no printf-style conversion specifiers such as %d may
+// remain in it); `source` is the location to which the message refers.
+struct Reason {
+  std::string text;
+  parser::CharBlock source;
+
+  Reason() = default;
+  // Take `t` by value and move it into place, so that an rvalue argument
+  // (e.g. the result of a formatting call) is not copied a second time.
+  Reason(std::string t, parser::CharBlock s) : text(std::move(t)), source(s) {}
+  // A Reason counts as present only when its source location is non-empty;
+  // `text` is not consulted.
+  operator bool() const { return !source.empty(); }
+};
+
+// Format `values` according to the printf-style format string `fmt` and
+// return the result as a std::string. As with printf, every conversion
+// specifier in `fmt` has to match the type of the corresponding argument.
+// Not declared `static`: a function template defined in a header does not
+// need internal linkage to be usable from several translation units.
+template <typename... Ts>
+std::string format(const char *fmt, Ts... values) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  os << llvm::format(fmt, values...);
+  // Flush explicitly and return the local itself: `return os.str()` would
+  // copy out of the reference it yields instead of moving `str`.
+  os.flush();
+  return str;
+}
+
+std::pair<std::optional<int64_t>, Reason> GetArgumentValueWithReason(
+ const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
+ unsigned version);
+std::pair<std::optional<int64_t>, Reason> GetNumArgumentsWithReason(
+ const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
+ unsigned version);
+
bool IsLoopTransforming(llvm::omp::Directive dir);
bool IsFullUnroll(const parser::OpenMPLoopConstruct &x);
@@ -116,13 +147,29 @@ std::optional<int64_t> GetNumGeneratedNestsFrom(
const parser::ExecutionPartConstruct &epc,
std::optional<int64_t> nestedCount);
+// Return the depth of the affected nests:
+// {affected-depth, must-be-perfect-nest, reason}.
+std::tuple<std::optional<int64_t>, bool, Reason> GetAffectedNestDepthWithReason(
+ const parser::OpenMPLoopConstruct &x, unsigned version);
+// Return the range of the affected nests in the sequence:
+// {first, count, reason}.
+// If the range is "the whole sequence", the return value will be {1, -1, ...}.
+std::tuple<std::optional<int64_t>, std::optional<int64_t>, Reason>
+GetAffectedLoopRangeWithReason(
+ const parser::OpenMPLoopConstruct &x, unsigned version);
+
+// Compute the number of loops required by the range {first, count}. If
+// count == -1, return -1, indicating all loops in the sequence.
+std::optional<int64_t> GetRequiredCount(
+ std::optional<int64_t> first, std::optional<int64_t> count);
+
struct LoopSequence {
- LoopSequence(
- const parser::ExecutionPartConstruct &root, bool allowAllLoops = false);
+ LoopSequence(const parser::ExecutionPartConstruct &root, unsigned version,
+ bool allowAllLoops = false);
template <typename R, typename = std::enable_if_t<is_range_v<R>>>
- LoopSequence(const R &range, bool allowAllLoops = false)
- : allowAllLoops_(allowAllLoops) {
+ LoopSequence(const R &range, unsigned version, bool allowAllLoops = false)
+ : version_(version), allowAllLoops_(allowAllLoops) {
entry_ = std::make_unique<Construct>(range, nullptr);
createChildrenFromRange(entry_->location);
calculateEverything();
@@ -145,7 +192,8 @@ struct LoopSequence {
private:
using Construct = ExecutionPartIterator::Construct;
- LoopSequence(std::unique_ptr<Construct> entry, bool allowAllLoops);
+ LoopSequence(
+ std::unique_ptr<Construct> entry, unsigned version, bool allowAllLoops);
template <typename R, typename = std::enable_if_t<is_range_v<R>>>
void createChildrenFromRange(const R &range) {
@@ -182,6 +230,7 @@ struct LoopSequence {
Depth depth_;
// The core structure of the class:
+ unsigned version_; // Needed for GetXyzWithReason
bool allowAllLoops_;
std::unique_ptr<Construct> entry_;
std::vector<LoopSequence> children_;
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index 6c0d6afed8696..b17ece42f475a 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -37,10 +37,6 @@
#include <tuple>
#include <variant>
-namespace Fortran::semantics {
-static std::optional<int64_t> GetNumGeneratedNests(const parser::Block &block);
-} // namespace Fortran::semantics
-
namespace {
using namespace Fortran;
@@ -244,19 +240,10 @@ void OmpStructureChecker::CheckSIMDNest(const parser::OpenMPConstruct &c) {
}
}
-static std::optional<int64_t> GetNumGeneratedNests(const parser::Block &block) {
- // Count the number of loops in the associated block. If there are any
- // malformed construct in there, getting the number may be meaningless.
- // These issues will be diagnosed elsewhere, and we should not emit any
- // messages about a potentially incorrect loop count.
- // In such cases reset the count to nullopt. Once it becomes nullopt,
- // keep it that way.
- return LoopSequence(block, true).length();
-}
-
void OmpStructureChecker::CheckNestedConstruct(
const parser::OpenMPLoopConstruct &x) {
const parser::OmpDirectiveSpecification &beginSpec{x.BeginDir()};
+ unsigned version{context_.langOptions().OpenMPVersion};
// End-directive is not allowed in such cases:
// do 100 i = ...
@@ -298,9 +285,11 @@ void OmpStructureChecker::CheckNestedConstruct(
}
}
+ LoopSequence sequence(body, version, true);
+
// Check if a loop-nest-associated construct has only one top-level loop
// in it.
- if (std::optional<int64_t> numLoops{GetNumGeneratedNests(body)}) {
+ if (std::optional<int64_t> numLoops{sequence.length()}) {
if (*numLoops == 0) {
context_.Say(beginSpec.DirName().source,
"This construct should contain a DO-loop or a loop-nest-generating OpenMP construct"_err_en_US);
@@ -514,20 +503,20 @@ void OmpStructureChecker::CheckDistLinear(
void OmpStructureChecker::CheckLooprangeBounds(
const parser::OpenMPLoopConstruct &x) {
+ unsigned version{context_.langOptions().OpenMPVersion};
if (auto *clause{parser::omp::FindClause(
x.BeginDir(), llvm::omp::Clause::OMPC_looprange)}) {
auto *lrClause{parser::Unwrap<parser::OmpLooprangeClause>(clause)};
auto first{GetIntValue(std::get<0>(lrClause->t))};
auto count{GetIntValue(std::get<1>(lrClause->t))};
- if (!first || !count || *first <= 0 || *count <= 0) {
- return;
- }
- int64_t requiredCount{*first + *count - 1};
- if (auto loopCount{GetNumGeneratedNests(std::get<parser::Block>(x.t))}) {
- if (*loopCount < requiredCount) {
- context_.Say(clause->source,
- "The specified loop range requires %ld loops, but the loop sequence has a length of %ld"_err_en_US,
- requiredCount, *loopCount);
+ if (auto requiredCount{GetRequiredCount(first, count)}) {
+ LoopSequence sequence(std::get<parser::Block>(x.t), version, true);
+ if (auto loopCount{sequence.length()}) {
+ if (*loopCount < *requiredCount) {
+ context_.Say(clause->source,
+ "The specified loop range requires %ld loops, but the loop sequence has a length of %ld"_err_en_US,
+ *requiredCount, *loopCount);
+ }
}
}
}
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index 8800d21c59cc1..63cd9b42e651c 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -528,6 +528,42 @@ MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp) {
instance.u);
}
+// Look up clause `clauseId` on `spec`. If the clause is present, its argument
+// unwraps to an expression, and GetIntValue yields a value for it, return
+// that value together with a Reason that names the clause (upper-cased, as
+// spelled for OpenMP `version`) and is anchored at the clause's source.
+// In every other case return {nullopt, empty Reason}.
+std::pair<std::optional<int64_t>, Reason> GetArgumentValueWithReason(
+    const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
+    unsigned version) {
+  auto *clause{parser::omp::FindClause(spec, clauseId)};
+  if (!clause) {
+    return {std::nullopt, Reason()};
+  }
+  auto *expr{parser::Unwrap<parser::Expr>(clause->u)};
+  if (!expr) {
+    return {std::nullopt, Reason()};
+  }
+  auto value{GetIntValue(*expr)};
+  if (!value) {
+    return {std::nullopt, Reason()};
+  }
+
+  llvm::StringRef name{llvm::omp::getOpenMPClauseName(clauseId, version)};
+  // int64_t is not `long` on every host (it is `long long` on LLP64 targets),
+  // so print it through an explicit `long long` with a matching %lld.
+  Reason reason( //
+      format("%s clause was specified with argument %lld",
+          parser::ToUpperCaseLetters(name.str()).c_str(),
+          static_cast<long long>(*value)),
+      clause->source);
+  return {*value, reason};
+}
+
+// Look up clause `clauseId` on `spec`. If the clause is present and its
+// argument unwraps to a list of scalar integer expressions, return the length
+// of that list together with a Reason that names the clause (upper-cased, as
+// spelled for OpenMP `version`) and is anchored at the clause's source.
+// In every other case return {nullopt, empty Reason}.
+std::pair<std::optional<int64_t>, Reason> GetNumArgumentsWithReason(
+    const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
+    unsigned version) {
+  using ArgumentList = std::list<parser::ScalarIntExpr>;
+
+  auto *clause{parser::omp::FindClause(spec, clauseId)};
+  if (!clause) {
+    return {std::nullopt, Reason()};
+  }
+  auto *args{parser::Unwrap<ArgumentList>(clause->u)};
+  if (!args) {
+    return {std::nullopt, Reason()};
+  }
+
+  llvm::StringRef name{llvm::omp::getOpenMPClauseName(clauseId, version)};
+  auto num{static_cast<int64_t>(args->size())};
+  // int64_t is not `long` on every host (it is `long long` on LLP64 targets),
+  // so print it through an explicit `long long` with a matching %lld.
+  Reason reason( //
+      format("%s clause was specified with %lld arguments",
+          parser::ToUpperCaseLetters(name.str()).c_str(),
+          static_cast<long long>(num)),
+      clause->source);
+  return {num, reason};
+}
+
bool IsLoopTransforming(llvm::omp::Directive dir) {
switch (dir) {
// TODO case llvm::omp::Directive::OMPD_flatten:
@@ -770,9 +806,132 @@ std::optional<int64_t> GetNumGeneratedNestsFrom(
return 1;
}
-LoopSequence::LoopSequence(
- const parser::ExecutionPartConstruct &root, bool allowAllLoops)
- : allowAllLoops_(allowAllLoops) {
+// Return the depth of the affected nests:
+// {affected-depth, must-be-perfect-nest, reason}.
+//
+// How the depth is obtained depends on the directive of `x`:
+// - Directive allows COLLAPSE or ORDERED in `version`: the larger of the
+//   COLLAPSE and ORDERED argument values that are available, with the reason
+//   of the clause that supplied it (nullopt if neither is available).
+// - INTERCHANGE: the number of arguments of PERMUTATION.
+// - STRIPE, TILE: the number of arguments of SIZES.
+// - FUSE: the argument of DEPTH if the clause is present, otherwise 1 with a
+//   reason stating that DEPTH was not specified.
+// - REVERSE, UNROLL: 1, no perfect nest required, empty reason.
+// All cases above except REVERSE/UNROLL require a perfect nest. For any
+// other directive the result is {nullopt, false, empty reason}.
+std::tuple<std::optional<int64_t>, bool, Reason> GetAffectedNestDepthWithReason(
+    const parser::OpenMPLoopConstruct &x, unsigned version) {
+  const parser::OmpDirectiveSpecification &beginSpec{x.BeginDir()};
+  llvm::omp::Directive dir{beginSpec.DirId()};
+  bool allowsCollapse{llvm::omp::isAllowedClauseForDirective(
+      dir, llvm::omp::Clause::OMPC_collapse, version)};
+  bool allowsOrdered{llvm::omp::isAllowedClauseForDirective(
+      dir, llvm::omp::Clause::OMPC_ordered, version)};
+
+  if (allowsCollapse || allowsOrdered) {
+    auto [count, reason]{GetArgumentValueWithReason(
+        beginSpec, llvm::omp::Clause::OMPC_collapse, version)};
+    auto [vo, ro]{GetArgumentValueWithReason(
+        beginSpec, llvm::omp::Clause::OMPC_ordered, version)};
+    // ORDERED takes over when it is the only value available, or when it is
+    // strictly larger than the COLLAPSE value.
+    if (vo && (!count || *count < *vo)) {
+      count = vo;
+      reason = ro;
+    }
+    return {count, true, reason};
+  }
+
+  if (IsLoopTransforming(dir)) {
+    switch (dir) {
+    case llvm::omp::Directive::OMPD_interchange: {
+      // Get the length of the argument list to PERMUTATION.
+      auto [num, reason]{GetNumArgumentsWithReason(
+          beginSpec, llvm::omp::Clause::OMPC_permutation, version)};
+      return {num, true, reason};
+    }
+    case llvm::omp::Directive::OMPD_stripe:
+    case llvm::omp::Directive::OMPD_tile: {
+      // Get the length of the argument list to SIZES.
+      auto [num, reason]{GetNumArgumentsWithReason(
+          beginSpec, llvm::omp::Clause::OMPC_sizes, version)};
+      return {num, true, reason};
+    }
+    case llvm::omp::Directive::OMPD_fuse: {
+      // Get the value from the argument to DEPTH.
+      if (parser::omp::FindClause(beginSpec, llvm::omp::Clause::OMPC_depth)) {
+        auto [count, reason]{GetArgumentValueWithReason(
+            beginSpec, llvm::omp::Clause::OMPC_depth, version)};
+        return {count, true, reason};
+      }
+      // No DEPTH clause: report the default, anchored at the directive.
+      std::string name{
+          parser::omp::GetUpperName(llvm::omp::Clause::OMPC_depth, version)};
+      Reason reason(
+          format("%s clause was not specified, a value of 1 was assumed",
+              name.c_str()),
+          beginSpec.source);
+      return {1, true, reason};
+    }
+    case llvm::omp::Directive::OMPD_reverse:
+    case llvm::omp::Directive::OMPD_unroll:
+      return {1, false, Reason()};
+    // TODO: case llvm::omp::Directive::OMPD_flatten:
+    // TODO: case llvm::omp::Directive::OMPD_split:
+    default:
+      break;
+    }
+  }
+
+  return {std::nullopt, false, Reason()};
+}
+
+// Return the range of the affected nests in the sequence:
+// {first, count, reason}.
+//
+// - FUSE with LOOPRANGE: the clause's first and count values with a reason
+//   anchored at the clause, or {nullopt, nullopt, empty reason} if either
+//   value is unavailable or not positive.
+// - FUSE without LOOPRANGE: {1, -1, reason}, where -1 means "the whole
+//   associated sequence" and the reason is anchored at the directive.
+// - Any other directive is expected to be loop-nest-associated (asserted) and
+//   yields {1, 1, empty reason}.
+std::tuple<std::optional<int64_t>, std::optional<int64_t>, Reason>
+GetAffectedLoopRangeWithReason(
+    const parser::OpenMPLoopConstruct &x, unsigned version) {
+  const parser::OmpDirectiveSpecification &beginSpec{x.BeginDir()};
+  llvm::omp::Directive dir{beginSpec.DirId()};
+
+  if (dir == llvm::omp::Directive::OMPD_fuse) {
+    // One spelling of the clause name serves both messages below.
+    std::string name{parser::omp::GetUpperName(
+        llvm::omp::Clause::OMPC_looprange, version)};
+    if (auto *clause{parser::omp::FindClause(
+            beginSpec, llvm::omp::Clause::OMPC_looprange)}) {
+      auto &range{DEREF(parser::Unwrap<parser::OmpLooprangeClause>(clause->u))};
+      std::optional<int64_t> first{GetIntValue(std::get<0>(range.t))};
+      std::optional<int64_t> count{GetIntValue(std::get<1>(range.t))};
+      if (!first || !count || *first <= 0 || *count <= 0) {
+        return {std::nullopt, std::nullopt, Reason()};
+      }
+      // int64_t is not `long` on every host (it is `long long` on LLP64
+      // targets), so print through an explicit `long long` with %lld.
+      Reason reason(
+          format("%s clause was specified with a count of %lld starting at "
+                 "loop %lld",
+              name.c_str(), static_cast<long long>(*count),
+              static_cast<long long>(*first)),
+          clause->source);
+      return {*first, *count, reason};
+    }
+    // If LOOPRANGE was not found, return {1, -1}, where -1 means "the whole
+    // associated sequence". The message has to be formatted here (Reason
+    // stores finished text, not a format string) and needs a real source
+    // location, so anchor it at the directive itself.
+    Reason reason(
+        format("%s clause was not specified, the entire loop sequence is "
+               "affected",
+            name.c_str()),
+        beginSpec.source);
+    return {1, -1, reason};
+  }
+
+  assert(llvm::omp::getDirectiveAssociation(dir) ==
+          llvm::omp::Association::LoopNest &&
+      "Expecting loop-nest-associated construct");
+  // For loop-nest constructs, a single loop-nest is affected.
+  return {1, 1, Reason()};
+}
+
+// Translate a loop range {first, count} into the number of loops the
+// sequence has to contain. A count of -1 is passed through as -1; a positive
+// count yields first + count - 1. If either value is missing, `first` is not
+// positive, or `count` is anything else, the result is nullopt.
+std::optional<int64_t> GetRequiredCount(
+    std::optional<int64_t> first, std::optional<int64_t> count) {
+  // Both values must be known, and the range must start at a positive index.
+  if (!first || !count || *first <= 0) {
+    return std::nullopt;
+  }
+  if (*count == -1) {
+    return -1;
+  }
+  if (*count > 0) {
+    return *first + *count - 1;
+  }
+  return std::nullopt;
+}
+
+LoopSequence::LoopSequence(const parser::ExecutionPartConstruct &root,
+ unsigned version, bool allowAllLoops)
+ : version_(version), allowAllLoops_(allowAllLoops) {
entry_ = createConstructEntry(root);
assert(entry_ && "Expecting loop like code");
@@ -780,8 +939,10 @@ LoopSequence::LoopSequence(
calculateEverything();
}
-LoopSequence::LoopSequence(std::unique_ptr<Construct> entry, bool allowAllLoops)
- : allowAllLoops_(allowAllLoops), entry_(std::move(entry)) {
+LoopSequence::LoopSequence(
+ std::unique_ptr<Construct> entry, unsigned version, bool allowAllLoops)
+ : version_(version), allowAllLoops_(allowAllLoops),
+ entry_(std::move(entry)) {
createChildrenFromRange(entry_->location);
calculateEverything();
}
@@ -811,7 +972,8 @@ void LoopSequence::createChildrenFromRange(
// case any code between consecutive children must be "transparent".
for (auto &code : BlockRange(begin, end, BlockRange::Step::Over)) {
if (auto entry{createConstructEntry(code)}) {
- children_.push_back(LoopSequence(std::move(entry), allowAllLoops_));
+ children_.push_back(
+ LoopSequence(std::move(entry), version_, allowAllLoops_));
if (!IsTransformableLoop(code)) {
hasInvalidIC_ = true;
hasOpaqueIC_ = true;
@@ -950,10 +1112,21 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
case llvm::omp::Directive::OMPD_fuse:
if (auto *clause{parser::omp::FindClause(
beginSpec, llvm::omp::Clause::OMPC_depth)}) {
- // FIXME: all loops must be fused for this
auto &expr{parser::UnwrapRef<parser::Expr>(clause->u)};
auto value{GetIntValue(expr)};
- return Depth{value, value};
+ auto nestedLength{getNestedLength()};
+ // The result is a perfect nest only if all loops in the sequence
+ // are fused.
+ if (value && nestedLength) {
+ auto [first, count, _]{GetAffectedLoopRangeWithReason(omp, version_)};
+ if (auto required{GetRequiredCount(first, count)}) {
+ if (*required == -1 || *required == *nestedLength) {
+ return Depth{value, value};
+ }
+ return Depth{1, 1};
+ }
+ }
+ return Depth{std::nullopt, std::nullopt};
}
return Depth{1, 1};
case llvm::omp::Directive::OMPD_interchange:
``````````
</details>
https://github.com/llvm/llvm-project/pull/185299
More information about the llvm-branch-commits
mailing list