[flang-commits] [flang] e442173 - [flang][OpenMP] Implement iterator that flattens BLOCK constructs (#180981)
via flang-commits
flang-commits at lists.llvm.org
Mon Feb 16 05:45:38 PST 2026
Author: Krzysztof Parzyszek
Date: 2026-02-16T07:45:34-06:00
New Revision: e442173c8bcd138a5193890954859e4809b2d4ea
URL: https://github.com/llvm/llvm-project/commit/e442173c8bcd138a5193890954859e4809b2d4ea
DIFF: https://github.com/llvm/llvm-project/commit/e442173c8bcd138a5193890954859e4809b2d4ea.diff
LOG: [flang][OpenMP] Implement iterator that flattens BLOCK constructs (#180981)
In OpenMP, a canonical loop nest may be enclosed in a BLOCK construct.
Specifically, the two loops below are considered to form a valid loop
sequence:
```f90
do i = 1, n
end do
block
do j = 1, m
end do
end block
```
Implement an extension to parser::Block::iterator that will treat the
example above as
```f90
do i = 1, n
end do
do j = 1, m
end do
```
that is, as if the BLOCK/ENDBLOCK statements were deleted. This will make
the analysis of loop nests easier, since code performing that analysis will
not have to deal with BLOCK constructs itself.
Added:
Modified:
flang/include/flang/Parser/openmp-utils.h
flang/lib/Parser/openmp-utils.cpp
flang/lib/Semantics/check-omp-loop.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index bd200558e4c59..65a9890d85293 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -16,9 +16,11 @@
#include "flang/Common/indirection.h"
#include "flang/Common/template.h"
#include "flang/Parser/parse-tree.h"
+#include "llvm/ADT/iterator_range.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include <cassert>
+#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
@@ -237,82 +239,188 @@ struct OmpAllocateInfo {
OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x);
-namespace detail {
-template <bool IsConst, typename T> struct ConstIf {
- using type = std::conditional_t<IsConst, std::add_const_t<T>, T>;
-};
+// Iterate over a range delimited by parser::Block::const_iterators. When the
+// end of the range is reached, the iterator becomes invalid.
+// Treat BLOCK constructs as if they were transparent, i.e. as if the
+// BLOCK/ENDBLOCK statements and the specification parts they contain
+// were removed. The stepping determines whether the iterator steps "into"
+// DO loops and OpenMP loop constructs, or steps "over" them.
+//
+// Example: consecutive locations of the iterator:
+//
+//        Step::Into                     Step::Over
+//        block                          block
+//   1 =>   stmt1                   1 =>   stmt1
+//          block                          block
+//            integer :: x                   integer :: x
+//   2 =>     stmt2                 2 =>     stmt2
+//            block                          block
+//            end block                      end block
+//          end block                      end block
+//   3 =>   do i = 1, n             3 =>   do i = 1, n
+//   4 =>     continue                       continue
+//          end do                         end do
+//   5 =>   stmt3                   4 =>   stmt3
+//        end block                      end block
+//
+//   6 => <invalid>                 5 => <invalid>
+//
+// The iterator is in a legal state (position) if it's at an
+// ExecutionPartConstruct that is not a BlockConstruct, or is invalid.
+struct ExecutionPartIterator {
+ enum class Step {
+ Into,
+ Over,
+ Default = Into,
+ };
+
+ using IteratorType = Block::const_iterator;
+ using IteratorRange = llvm::iterator_range<IteratorType>;
+
+ struct Construct {
+ Construct(IteratorType b, IteratorType e, const ExecutionPartConstruct *c)
+ : range(b, e), owner(c) {}
+ template <typename R>
+ Construct(const R &r, const ExecutionPartConstruct *c)
+ : range(r), owner(c) {}
+ Construct(const Construct &c) = default;
+ IteratorRange range;
+ const ExecutionPartConstruct *owner;
+ };
+
+ ExecutionPartIterator() = default;
+
+ ExecutionPartIterator(IteratorType b, IteratorType e, Step s = Step::Default,
+ const ExecutionPartConstruct *c = nullptr)
+ : stepping_(s) {
+ stack_.emplace_back(b, e, c);
+ adjust();
+ }
+ template <typename R, //
+ typename = decltype(std::declval<R>().begin()),
+ typename = decltype(std::declval<R>().end())>
+ ExecutionPartIterator(const R &range, Step stepping = Step::Default,
+ const ExecutionPartConstruct *construct = nullptr)
+ : ExecutionPartIterator(range.begin(), range.end(), stepping, construct) {
+ }
-template <bool IsConst, typename T>
-using ConstIfT = typename ConstIf<IsConst, T>::type;
-} // namespace detail
+ // Advance the iterator to the next legal position. If the current position
+ // is a DO-loop or a loop construct, step into the contained Block.
+ void step();
-template <bool IsConst> struct LoopRange {
- using QualBlock = detail::ConstIfT<IsConst, Block>;
- using QualReference = decltype(std::declval<QualBlock>().front());
- using QualPointer = std::remove_reference_t<QualReference> *;
+ // Advance the iterator to the next legal position. If the current position
+ // is a DO-loop or a loop construct, step to the next legal position following
+ // the DO-loop or loop construct.
+ void next();
- LoopRange(QualBlock &x) { Initialize(x); }
- LoopRange(QualReference x);
+ bool valid() const { return !stack_.empty(); }
- LoopRange(detail::ConstIfT<IsConst, OpenMPLoopConstruct> &x)
- : LoopRange(std::get<Block>(x.t)) {}
- LoopRange(detail::ConstIfT<IsConst, DoConstruct> &x)
- : LoopRange(std::get<Block>(x.t)) {}
+ decltype(auto) operator*() const { return *at(); }
+ bool operator==(const ExecutionPartIterator &other) const {
+ if (valid() != other.valid()) {
+ return false;
+ }
+ // Invalid iterators are considered equal.
+ return !valid() ||
+ stack_.back().range.begin() == other.stack_.back().range.begin();
+ }
+ bool operator!=(const ExecutionPartIterator &other) const {
+ return !(*this == other);
+ }
- size_t size() const { return items.size(); }
- bool empty() const { return items.size() == 0; }
+ ExecutionPartIterator &operator++() {
+ if (stepping_ == Step::Into) {
+ step();
+ } else {
+ assert(stepping_ == Step::Over && "Unexpected stepping");
+ next();
+ }
+ return *this;
+ }
- struct iterator;
+ ExecutionPartIterator operator++(int) {
+ ExecutionPartIterator copy{*this};
+ operator++();
+ return copy;
+ }
- iterator begin();
- iterator end();
+ using difference_type = IteratorType::difference_type;
+ using value_type = IteratorType::value_type;
+ using reference = IteratorType::reference;
+ using pointer = IteratorType::pointer;
+ using iterator_category = std::forward_iterator_tag;
private:
- void Initialize(QualBlock &body);
+ IteratorType at() const { return stack_.back().range.begin(); };
- std::vector<QualPointer> items;
+ // If the iterator is not at a legal location, keep advancing it until
+ // it lands at a legal location or becomes invalid.
+ void adjust();
+
+ const Step stepping_ = Step::Default;
+ std::vector<Construct> stack_;
};
-template <typename T> LoopRange(T &x) -> LoopRange<std::is_const_v<T>>;
+template <typename Iterator = ExecutionPartIterator> struct ExecutionPartRange {
+ using Step = typename Iterator::Step;
-template <bool IsConst> struct LoopRange<IsConst>::iterator {
- QualReference operator*() { return **at; }
+ ExecutionPartRange(Block::const_iterator begin, Block::const_iterator end,
+ Step stepping = Step::Default,
+ const ExecutionPartConstruct *owner = nullptr)
+ : begin_(begin, end, stepping, owner), end_() {}
+ template <typename R, //
+ typename = decltype(std::declval<R>().begin()),
+ typename = decltype(std::declval<R>().end())>
+ ExecutionPartRange(const R &range, Step stepping = Step::Default,
+ const ExecutionPartConstruct *owner = nullptr)
+ : ExecutionPartRange(range.begin(), range.end(), stepping, owner) {}
- bool operator==(const iterator &other) const { return at == other.at; }
- bool operator!=(const iterator &other) const { return at != other.at; }
+ Iterator begin() const { return begin_; }
+ Iterator end() const { return end_; }
- iterator &operator++() {
- ++at;
- return *this;
+private:
+ Iterator begin_, end_;
+};
+
+struct LoopNestIterator : public ExecutionPartIterator {
+ LoopNestIterator() = default;
+
+ LoopNestIterator(IteratorType b, IteratorType e, Step s = Step::Default,
+ const ExecutionPartConstruct *c = nullptr)
+ : ExecutionPartIterator(b, e, s, c) {
+ adjust();
}
- iterator &operator--() {
- --at;
+ template <typename R, //
+ typename = decltype(std::declval<R>().begin()),
+ typename = decltype(std::declval<R>().end())>
+ LoopNestIterator(const R &range, Step stepping = Step::Default,
+ const ExecutionPartConstruct *construct = nullptr)
+ : LoopNestIterator(range.begin(), range.end(), stepping, construct) {}
+
+ LoopNestIterator &operator++() {
+ ExecutionPartIterator::operator++();
+ adjust();
return *this;
}
- iterator operator++(int);
- iterator operator--(int);
-
-private:
- friend struct LoopRange;
- typename decltype(LoopRange::items)::iterator at;
-};
-template <bool IsConst> inline auto LoopRange<IsConst>::begin() -> iterator {
- iterator x;
- x.at = items.begin();
- return x;
-}
+ LoopNestIterator operator++(int) {
+ LoopNestIterator copy{*this};
+ operator++();
+ return copy;
+ }
-template <bool IsConst> inline auto LoopRange<IsConst>::end() -> iterator {
- iterator x;
- x.at = items.end();
- return x;
-}
+private:
+ static bool isLoop(const ExecutionPartConstruct &c);
-using ConstLoopRange = LoopRange<true>;
+ void adjust() {
+ while (valid() && !isLoop(**this)) {
+ ExecutionPartIterator::operator++();
+ }
+ }
+};
-extern template struct LoopRange<true>;
-extern template struct LoopRange<false>;
+using BlockRange = ExecutionPartRange<ExecutionPartIterator>;
+using LoopRange = ExecutionPartRange<LoopNestIterator>;
} // namespace Fortran::parser::omp
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index a9dbb55819b1e..9aa7014ccb057 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -211,53 +211,54 @@ OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x) {
return info;
}
-template <bool IsConst> LoopRange<IsConst>::LoopRange(QualReference x) {
- if (auto *doLoop{Unwrap<DoConstruct>(x)}) {
- Initialize(std::get<Block>(doLoop->t));
- } else if (auto *omp{Unwrap<OpenMPLoopConstruct>(x)}) {
- Initialize(std::get<Block>(omp->t));
+void ExecutionPartIterator::step() {
+ // Advance the iterator to the next legal position. If the current
+ // position is a DO-loop or a loop construct, step into it.
+ if (valid()) {
+ IteratorType where{at()};
+ if (auto *loop{GetOmpLoop(*where)}) {
+ stack_.emplace_back(std::get<Block>(loop->t), &*where);
+ } else if (auto *loop{GetDoConstruct(*where)}) {
+ stack_.emplace_back(std::get<Block>(loop->t), &*where);
+ } else {
+ stack_.back().range =
+ IteratorRange(std::next(where), stack_.back().range.end());
+ }
+ adjust();
}
}
-template <bool IsConst> void LoopRange<IsConst>::Initialize(QualBlock &body) {
- using QualIterator = decltype(std::declval<QualBlock>().begin());
- auto makeRange{[](auto &container) {
- return llvm::make_range(container.begin(), container.end());
- }};
+void ExecutionPartIterator::next() {
+ // Advance the iterator to the next legal position. If the current
+ // position is a DO-loop or a loop construct, step over it.
+ if (valid()) {
+ stack_.back().range =
+ IteratorRange(std::next(at()), stack_.back().range.end());
+ adjust();
+ }
+}
- std::vector<llvm::iterator_range<QualIterator>> nest{makeRange(body)};
- do {
- auto at{nest.back().begin()};
- auto end{nest.back().end()};
- nest.pop_back();
- while (at != end) {
- if (auto *block{Unwrap<BlockConstruct>(*at)}) {
- nest.push_back(llvm::make_range(std::next(at), end));
- nest.push_back(makeRange(std::get<Block>(block->t)));
- break;
- } else if (Unwrap<DoConstruct>(*at) || Unwrap<OpenMPLoopConstruct>(*at)) {
- items.push_back(&*at);
+void ExecutionPartIterator::adjust() {
+ // If the iterator is not at a legal location, keep advancing it until
+ // it lands at a legal location or becomes invalid.
+ while (valid()) {
+ if (stack_.back().range.empty()) {
+ stack_.pop_back();
+ if (valid()) {
+ stack_.back().range =
+ IteratorRange(std::next(at()), stack_.back().range.end());
}
- ++at;
+ } else if (auto *block{GetFortranBlockConstruct(*at())}) {
+ stack_.emplace_back(std::get<Block>(block->t), &*at());
+ } else {
+ break;
}
- } while (!nest.empty());
-}
-
-template <bool IsConst>
-auto LoopRange<IsConst>::iterator::operator++(int) -> iterator {
- auto old = *this;
- ++*this;
- return old;
+ }
}
-template <bool IsConst>
-auto LoopRange<IsConst>::iterator::operator--(int) -> iterator {
- auto old = *this;
- --*this;
- return old;
+bool LoopNestIterator::isLoop(const ExecutionPartConstruct &c) {
+ return Unwrap<OpenMPLoopConstruct>(c) != nullptr ||
+ Unwrap<DoConstruct>(c) != nullptr;
}
-template struct LoopRange<false>;
-template struct LoopRange<true>;
-
} // namespace Fortran::parser::omp
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index 20c52ba3417ad..02580e16a8bf4 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -272,7 +272,8 @@ static bool IsLoopTransforming(llvm::omp::Directive dir) {
void OmpStructureChecker::CheckNestedBlock(
const parser::OpenMPLoopConstruct &x, const parser::Block &body) {
- for (auto &stmt : body) {
+ using BlockRange = parser::omp::BlockRange;
+ for (auto &stmt : BlockRange(body, BlockRange::Step::Over)) {
if (auto *dir{parser::Unwrap<parser::CompilerDirective>(stmt)}) {
context_.Say(dir->source,
"Compiler directives are not allowed inside OpenMP loop constructs"_warn_en_US);
@@ -281,8 +282,6 @@ void OmpStructureChecker::CheckNestedBlock(
context_.Say(omp->source,
"Only loop-transforming OpenMP constructs are allowed inside OpenMP loop constructs"_err_en_US);
}
- } else if (auto *block{parser::Unwrap<parser::BlockConstruct>(stmt)}) {
- CheckNestedBlock(x, std::get<parser::Block>(block->t));
} else if (!parser::Unwrap<parser::DoConstruct>(stmt)) {
parser::CharBlock source{parser::GetSource(stmt).value_or(x.source)};
context_.Say(source,
@@ -348,8 +347,9 @@ static std::optional<size_t> CountGeneratedNests(const parser::Block &block) {
// messages about a potentially incorrect loop count.
// In such cases reset the count to nullopt. Once it becomes nullopt,
// keep it that way.
+ using LoopRange = parser::omp::LoopRange;
std::optional<size_t> numLoops{0};
- for (auto &epc : parser::omp::LoopRange(block)) {
+ for (auto &epc : LoopRange(block, LoopRange::Step::Over)) {
if (auto genCount{CountGeneratedNests(epc)}) {
*numLoops += *genCount;
} else {
More information about the flang-commits
mailing list