[flang-commits] [flang] [flang][OpenMP] Implement iterator that flattens BLOCK constructs (PR #180981)

Krzysztof Parzyszek via flang-commits flang-commits at lists.llvm.org
Wed Feb 11 10:26:06 PST 2026


https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/180981

>From 0e4e8ab05da4177c2c7b89bee1f2290942398634 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Mon, 9 Feb 2026 07:30:35 -0600
Subject: [PATCH 1/3] [flang][OpenMP] Implement iterator that flattens BLOCK
 constructs

In OpenMP a canonical loop nest may be enclosed in a BLOCK construct.
Specifically, the two loops below are considered to form a valid loop
sequence:
  do i = 1, n
  end do
  block
    do j = 1, m
    end do
  end block

Implement an iterator, built on top of parser::Block::const_iterator,
that will treat the example above as
  do i = 1, n
  end do
  do j = 1, m
  end do
that is, as if the BLOCK and END BLOCK statements (along with any
declarations inside the BLOCK) were deleted. This will make the analysis
of loop nests easier, since the code performing that analysis will no
longer have to handle BLOCK constructs itself.
---
 flang/include/flang/Parser/openmp-utils.h | 182 +++++++++++++++-------
 flang/lib/Parser/openmp-utils.cpp         |  78 +++++-----
 flang/lib/Semantics/check-omp-loop.cpp    |   8 +-
 3 files changed, 167 insertions(+), 101 deletions(-)

diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index bd200558e4c59..3d42d9943f76b 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -16,6 +16,7 @@
 #include "flang/Common/indirection.h"
 #include "flang/Common/template.h"
 #include "flang/Parser/parse-tree.h"
+#include "llvm/ADT/iterator_range.h"
 #include "llvm/Frontend/OpenMP/OMP.h"
 
 #include <cassert>
@@ -237,82 +238,151 @@ struct OmpAllocateInfo {
 
 OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x);
 
-namespace detail {
-template <bool IsConst, typename T> struct ConstIf {
-  using type = std::conditional_t<IsConst, std::add_const_t<T>, T>;
-};
-
-template <bool IsConst, typename T>
-using ConstIfT = typename ConstIf<IsConst, T>::type;
-} // namespace detail
-
-template <bool IsConst> struct LoopRange {
-  using QualBlock = detail::ConstIfT<IsConst, Block>;
-  using QualReference = decltype(std::declval<QualBlock>().front());
-  using QualPointer = std::remove_reference_t<QualReference> *;
+// Iterate over the constructs in a range of parser::Block::const_iterators.
+// When the end of the range is reached, the iterator becomes invalid.
+// BLOCK constructs are treated as transparent, i.e. as if the BLOCK and
+// END BLOCK statements, together with the BLOCK's specification part,
+// were removed. For example, given the range:
+//   stmt1
+//   block
+//     integer :: x
+//     stmt2
+//     block
+//     end block
+//   end block
+//   stmt3
+// the iterator will return stmt1, stmt2, stmt3 in that order, then will
+// become invalid.
+//
+// The iterator is in a legal state (position) if it's at an
+// ExecutionPartConstruct that is not a BlockConstruct, or is invalid.
+struct ExecutionPartIterator {
+  enum class Step {
+    Into,
+    Over,
+    Default = Into,
+  };
+
+  using IteratorType = parser::Block::const_iterator;
+  using IteratorRange = llvm::iterator_range<IteratorType>;
+
+  struct Construct {
+    Construct(IteratorType b, IteratorType e,
+        const parser::ExecutionPartConstruct *c = nullptr)
+        : range(b, e), owner(c) {}
+    template <typename C>
+    Construct(C &&r, const parser::ExecutionPartConstruct *c = nullptr)
+        : range(r), owner(c) {}
+    IteratorRange range;
+    const parser::ExecutionPartConstruct *owner = nullptr;
+  };
+
+  ExecutionPartIterator() = default;
+
+  ExecutionPartIterator(IteratorType b, IteratorType e, Step s = Step::Default,
+      const parser::ExecutionPartConstruct *c = nullptr)
+      : stepping_(s) {
+    stack_.emplace_back(b, e, c);
+    adjust();
+  }
+  template <typename C>
+  ExecutionPartIterator(C &&range, Step stepping = Step::Default,
+      const parser::ExecutionPartConstruct *construct = nullptr)
+      : ExecutionPartIterator(range.begin(), range.end(), stepping, construct) {
+  }
 
-  LoopRange(QualBlock &x) { Initialize(x); }
-  LoopRange(QualReference x);
+  // Advance the iterator to the next legal position. If the current position
+  // is a DO-loop or a loop construct, step into the contained Block.
+  void step();
 
-  LoopRange(detail::ConstIfT<IsConst, OpenMPLoopConstruct> &x)
-      : LoopRange(std::get<Block>(x.t)) {}
-  LoopRange(detail::ConstIfT<IsConst, DoConstruct> &x)
-      : LoopRange(std::get<Block>(x.t)) {}
+  // Advance the iterator to the next legal position. If the current position
+  // is a DO-loop or a loop construct, step to the next legal position following
+  // the DO-loop or loop construct.
+  void next();
 
-  size_t size() const { return items.size(); }
-  bool empty() const { return items.size() == 0; }
+  bool valid() const { return !stack_.empty(); }
 
-  struct iterator;
+  decltype(auto) operator*() const { return *at(); }
+  bool operator==(const ExecutionPartIterator &other) const {
+    if (valid() != other.valid()) {
+      return false;
+    }
+    // Invalid iterators are considered equal.
+    return !valid() ||
+        stack_.back().range.begin() == other.stack_.back().range.begin();
+  }
+  bool operator!=(const ExecutionPartIterator &other) const {
+    return !(*this == other);
+  }
 
-  iterator begin();
-  iterator end();
+  ExecutionPartIterator &operator++() {
+    if (stepping_ == Step::Into) {
+      step();
+    } else {
+      assert(stepping_ == Step::Over && "Unexpected stepping");
+      next();
+    }
+    return *this;
+  }
 
 private:
-  void Initialize(QualBlock &body);
+  IteratorType at() const { return stack_.back().range.begin(); };
 
-  std::vector<QualPointer> items;
+  // If the iterator is not at a legal location, keep advancing it until
+  // it lands at a legal location or becomes invalid.
+  void adjust();
+
+  const Step stepping_ = Step::Default;
+  std::vector<Construct> stack_;
 };
 
-template <typename T> LoopRange(T &x) -> LoopRange<std::is_const_v<T>>;
+template <typename Iterator = ExecutionPartIterator> struct ExecutionPartRange {
+  using Step = typename Iterator::Step;
 
-template <bool IsConst> struct LoopRange<IsConst>::iterator {
-  QualReference operator*() { return **at; }
+  ExecutionPartRange(parser::Block::const_iterator begin,
+      parser::Block::const_iterator end, Step stepping = Step::Default,
+      const parser::ExecutionPartConstruct *owner = nullptr)
+      : begin_(begin, end, stepping, owner), end_() {}
+  ExecutionPartRange(const parser::Block &range, Step stepping = Step::Default,
+      const parser::ExecutionPartConstruct *owner = nullptr)
+      : ExecutionPartRange(range.begin(), range.end(), stepping, owner) {}
 
-  bool operator==(const iterator &other) const { return at == other.at; }
-  bool operator!=(const iterator &other) const { return at != other.at; }
+  Iterator begin() const { return begin_; }
+  Iterator end() const { return end_; }
 
-  iterator &operator++() {
-    ++at;
-    return *this;
+private:
+  Iterator begin_, end_;
+};
+
+struct LoopNestIterator : public ExecutionPartIterator {
+  LoopNestIterator() = default;
+  LoopNestIterator(IteratorType b, IteratorType e, Step s = Step::Default,
+      const parser::ExecutionPartConstruct *c = nullptr)
+      : ExecutionPartIterator(b, e, s, c) {
+    adjust();
   }
-  iterator &operator--() {
-    --at;
+
+  LoopNestIterator &operator++() {
+    ExecutionPartIterator::operator++();
+    adjust();
     return *this;
   }
-  iterator operator++(int);
-  iterator operator--(int);
 
 private:
-  friend struct LoopRange;
-  typename decltype(LoopRange::items)::iterator at;
-};
-
-template <bool IsConst> inline auto LoopRange<IsConst>::begin() -> iterator {
-  iterator x;
-  x.at = items.begin();
-  return x;
-}
-
-template <bool IsConst> inline auto LoopRange<IsConst>::end() -> iterator {
-  iterator x;
-  x.at = items.end();
-  return x;
-}
+  static bool isLoop(const parser::ExecutionPartConstruct &c) {
+    return parser::Unwrap<parser::OpenMPLoopConstruct>(c) != nullptr ||
+        parser::Unwrap<parser::DoConstruct>(c) != nullptr;
+  }
 
-using ConstLoopRange = LoopRange<true>;
+  void adjust() {
+    while (valid() && !isLoop(**this)) {
+      ExecutionPartIterator::operator++();
+    }
+  }
+};
 
-extern template struct LoopRange<true>;
-extern template struct LoopRange<false>;
+using BlockRange = ExecutionPartRange<ExecutionPartIterator>;
+using LoopRange = ExecutionPartRange<LoopNestIterator>;
 
 } // namespace Fortran::parser::omp
 
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index a9dbb55819b1e..fe495ac957eec 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -211,53 +211,49 @@ OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x) {
   return info;
 }
 
-template <bool IsConst> LoopRange<IsConst>::LoopRange(QualReference x) {
-  if (auto *doLoop{Unwrap<DoConstruct>(x)}) {
-    Initialize(std::get<Block>(doLoop->t));
-  } else if (auto *omp{Unwrap<OpenMPLoopConstruct>(x)}) {
-    Initialize(std::get<Block>(omp->t));
+void ExecutionPartIterator::step() {
+  // Advance the iterator to the next legal position. If the current
+  // position is a DO-loop or a loop construct, step into it.
+  if (valid()) {
+    IteratorType where{at()};
+    if (auto *loop{parser::omp::GetOmpLoop(*where)}) {
+      stack_.emplace_back(std::get<parser::Block>(loop->t), &*where);
+    } else if (auto *loop{parser::omp::GetDoConstruct(*where)}) {
+      stack_.emplace_back(std::get<parser::Block>(loop->t), &*where);
+    } else {
+      stack_.back().range =
+          IteratorRange(std::next(where), stack_.back().range.end());
+    }
+    adjust();
   }
 }
 
-template <bool IsConst> void LoopRange<IsConst>::Initialize(QualBlock &body) {
-  using QualIterator = decltype(std::declval<QualBlock>().begin());
-  auto makeRange{[](auto &container) {
-    return llvm::make_range(container.begin(), container.end());
-  }};
+void ExecutionPartIterator::next() {
+  // Advance the iterator to the next legal position. If the current
+  // position is a DO-loop or a loop construct, step over it.
+  if (valid()) {
+    stack_.back().range =
+        IteratorRange(std::next(at()), stack_.back().range.end());
+    adjust();
+  }
+}
 
-  std::vector<llvm::iterator_range<QualIterator>> nest{makeRange(body)};
-  do {
-    auto at{nest.back().begin()};
-    auto end{nest.back().end()};
-    nest.pop_back();
-    while (at != end) {
-      if (auto *block{Unwrap<BlockConstruct>(*at)}) {
-        nest.push_back(llvm::make_range(std::next(at), end));
-        nest.push_back(makeRange(std::get<Block>(block->t)));
-        break;
-      } else if (Unwrap<DoConstruct>(*at) || Unwrap<OpenMPLoopConstruct>(*at)) {
-        items.push_back(&*at);
+void ExecutionPartIterator::adjust() {
+  // If the iterator is not at a legal location, keep advancing it until
+  // it lands at a legal location or becomes invalid.
+  while (valid()) {
+    if (stack_.back().range.empty()) {
+      stack_.pop_back();
+      if (valid()) {
+        stack_.back().range =
+            IteratorRange(std::next(at()), stack_.back().range.end());
       }
-      ++at;
+    } else if (auto *block{parser::omp::GetFortranBlockConstruct(*at())}) {
+      stack_.emplace_back(std::get<parser::Block>(block->t), &*at());
+    } else {
+      break;
     }
-  } while (!nest.empty());
-}
-
-template <bool IsConst>
-auto LoopRange<IsConst>::iterator::operator++(int) -> iterator {
-  auto old = *this;
-  ++*this;
-  return old;
-}
-
-template <bool IsConst>
-auto LoopRange<IsConst>::iterator::operator--(int) -> iterator {
-  auto old = *this;
-  --*this;
-  return old;
+  }
 }
 
-template struct LoopRange<false>;
-template struct LoopRange<true>;
-
 } // namespace Fortran::parser::omp
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index 20c52ba3417ad..02580e16a8bf4 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -272,7 +272,8 @@ static bool IsLoopTransforming(llvm::omp::Directive dir) {
 
 void OmpStructureChecker::CheckNestedBlock(
     const parser::OpenMPLoopConstruct &x, const parser::Block &body) {
-  for (auto &stmt : body) {
+  using BlockRange = parser::omp::BlockRange;
+  for (auto &stmt : BlockRange(body, BlockRange::Step::Over)) {
     if (auto *dir{parser::Unwrap<parser::CompilerDirective>(stmt)}) {
       context_.Say(dir->source,
           "Compiler directives are not allowed inside OpenMP loop constructs"_warn_en_US);
@@ -281,8 +282,6 @@ void OmpStructureChecker::CheckNestedBlock(
         context_.Say(omp->source,
             "Only loop-transforming OpenMP constructs are allowed inside OpenMP loop constructs"_err_en_US);
       }
-    } else if (auto *block{parser::Unwrap<parser::BlockConstruct>(stmt)}) {
-      CheckNestedBlock(x, std::get<parser::Block>(block->t));
     } else if (!parser::Unwrap<parser::DoConstruct>(stmt)) {
       parser::CharBlock source{parser::GetSource(stmt).value_or(x.source)};
       context_.Say(source,
@@ -348,8 +347,9 @@ static std::optional<size_t> CountGeneratedNests(const parser::Block &block) {
   // messages about a potentially incorrect loop count.
   // In such cases reset the count to nullopt. Once it becomes nullopt,
   // keep it that way.
+  using LoopRange = parser::omp::LoopRange;
   std::optional<size_t> numLoops{0};
-  for (auto &epc : parser::omp::LoopRange(block)) {
+  for (auto &epc : LoopRange(block, LoopRange::Step::Over)) {
     if (auto genCount{CountGeneratedNests(epc)}) {
       *numLoops += *genCount;
     } else {

>From f521f21cbac88d5ce4749e08a82191aa24335a83 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 11 Feb 2026 12:19:52 -0600
Subject: [PATCH 2/3] Move calls to Unwrap to .cpp file

---
 flang/include/flang/Parser/openmp-utils.h | 5 +----
 flang/lib/Parser/openmp-utils.cpp         | 5 +++++
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index 3d42d9943f76b..1047cde5e3a5b 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -369,10 +369,7 @@ struct LoopNestIterator : public ExecutionPartIterator {
   }
 
 private:
-  static bool isLoop(const parser::ExecutionPartConstruct &c) {
-    return parser::Unwrap<parser::OpenMPLoopConstruct>(c) != nullptr ||
-        parser::Unwrap<parser::DoConstruct>(c) != nullptr;
-  }
+  static bool isLoop(const parser::ExecutionPartConstruct &c);
 
   void adjust() {
     while (valid() && !isLoop(**this)) {
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index fe495ac957eec..70482896b83af 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -256,4 +256,9 @@ void ExecutionPartIterator::adjust() {
   }
 }
 
+bool LoopNestIterator::isLoop(const parser::ExecutionPartConstruct &c) {
+  return parser::Unwrap<parser::OpenMPLoopConstruct>(c) != nullptr ||
+      parser::Unwrap<parser::DoConstruct>(c) != nullptr;
+}
+
 } // namespace Fortran::parser::omp

>From 6f78875afbd3be31811949e442eb3adf71146485 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 11 Feb 2026 12:25:34 -0600
Subject: [PATCH 3/3] Remove extraneous namespace qualifications

---
 flang/include/flang/Parser/openmp-utils.h | 26 +++++++++++------------
 flang/lib/Parser/openmp-utils.cpp         | 18 ++++++++--------
 2 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index 1047cde5e3a5b..ada70e00f17b3 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -263,31 +263,31 @@ struct ExecutionPartIterator {
     Default = Into,
   };
 
-  using IteratorType = parser::Block::const_iterator;
+  using IteratorType = Block::const_iterator;
   using IteratorRange = llvm::iterator_range<IteratorType>;
 
   struct Construct {
     Construct(IteratorType b, IteratorType e,
-        const parser::ExecutionPartConstruct *c = nullptr)
+        const ExecutionPartConstruct *c = nullptr)
         : range(b, e), owner(c) {}
     template <typename C>
-    Construct(C &&r, const parser::ExecutionPartConstruct *c = nullptr)
+    Construct(C &&r, const ExecutionPartConstruct *c = nullptr)
         : range(r), owner(c) {}
     IteratorRange range;
-    const parser::ExecutionPartConstruct *owner = nullptr;
+    const ExecutionPartConstruct *owner = nullptr;
   };
 
   ExecutionPartIterator() = default;
 
   ExecutionPartIterator(IteratorType b, IteratorType e, Step s = Step::Default,
-      const parser::ExecutionPartConstruct *c = nullptr)
+      const ExecutionPartConstruct *c = nullptr)
       : stepping_(s) {
     stack_.emplace_back(b, e, c);
     adjust();
   }
   template <typename C>
   ExecutionPartIterator(C &&range, Step stepping = Step::Default,
-      const parser::ExecutionPartConstruct *construct = nullptr)
+      const ExecutionPartConstruct *construct = nullptr)
       : ExecutionPartIterator(range.begin(), range.end(), stepping, construct) {
   }
 
@@ -339,12 +339,12 @@ struct ExecutionPartIterator {
 template <typename Iterator = ExecutionPartIterator> struct ExecutionPartRange {
   using Step = typename Iterator::Step;
 
-  ExecutionPartRange(parser::Block::const_iterator begin,
-      parser::Block::const_iterator end, Step stepping = Step::Default,
-      const parser::ExecutionPartConstruct *owner = nullptr)
+  ExecutionPartRange(Block::const_iterator begin, Block::const_iterator end,
+      Step stepping = Step::Default,
+      const ExecutionPartConstruct *owner = nullptr)
       : begin_(begin, end, stepping, owner), end_() {}
-  ExecutionPartRange(const parser::Block &range, Step stepping = Step::Default,
-      const parser::ExecutionPartConstruct *owner = nullptr)
+  ExecutionPartRange(const Block &range, Step stepping = Step::Default,
+      const ExecutionPartConstruct *owner = nullptr)
       : ExecutionPartRange(range.begin(), range.end(), stepping, owner) {}
 
   Iterator begin() const { return begin_; }
@@ -357,7 +357,7 @@ template <typename Iterator = ExecutionPartIterator> struct ExecutionPartRange {
 struct LoopNestIterator : public ExecutionPartIterator {
   LoopNestIterator() = default;
   LoopNestIterator(IteratorType b, IteratorType e, Step s = Step::Default,
-      const parser::ExecutionPartConstruct *c = nullptr)
+      const ExecutionPartConstruct *c = nullptr)
       : ExecutionPartIterator(b, e, s, c) {
     adjust();
   }
@@ -369,7 +369,7 @@ struct LoopNestIterator : public ExecutionPartIterator {
   }
 
 private:
-  static bool isLoop(const parser::ExecutionPartConstruct &c);
+  static bool isLoop(const ExecutionPartConstruct &c);
 
   void adjust() {
     while (valid() && !isLoop(**this)) {
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index 70482896b83af..9aa7014ccb057 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -216,10 +216,10 @@ void ExecutionPartIterator::step() {
   // position is a DO-loop or a loop construct, step into it.
   if (valid()) {
     IteratorType where{at()};
-    if (auto *loop{parser::omp::GetOmpLoop(*where)}) {
-      stack_.emplace_back(std::get<parser::Block>(loop->t), &*where);
-    } else if (auto *loop{parser::omp::GetDoConstruct(*where)}) {
-      stack_.emplace_back(std::get<parser::Block>(loop->t), &*where);
+    if (auto *loop{GetOmpLoop(*where)}) {
+      stack_.emplace_back(std::get<Block>(loop->t), &*where);
+    } else if (auto *loop{GetDoConstruct(*where)}) {
+      stack_.emplace_back(std::get<Block>(loop->t), &*where);
     } else {
       stack_.back().range =
           IteratorRange(std::next(where), stack_.back().range.end());
@@ -248,17 +248,17 @@ void ExecutionPartIterator::adjust() {
         stack_.back().range =
             IteratorRange(std::next(at()), stack_.back().range.end());
       }
-    } else if (auto *block{parser::omp::GetFortranBlockConstruct(*at())}) {
-      stack_.emplace_back(std::get<parser::Block>(block->t), &*at());
+    } else if (auto *block{GetFortranBlockConstruct(*at())}) {
+      stack_.emplace_back(std::get<Block>(block->t), &*at());
     } else {
       break;
     }
   }
 }
 
-bool LoopNestIterator::isLoop(const parser::ExecutionPartConstruct &c) {
-  return parser::Unwrap<parser::OpenMPLoopConstruct>(c) != nullptr ||
-      parser::Unwrap<parser::DoConstruct>(c) != nullptr;
+bool LoopNestIterator::isLoop(const ExecutionPartConstruct &c) {
+  return Unwrap<OpenMPLoopConstruct>(c) != nullptr ||
+      Unwrap<DoConstruct>(c) != nullptr;
 }
 
 } // namespace Fortran::parser::omp



More information about the flang-commits mailing list