[Mlir-commits] [flang] [mlir] [flang][OpenMP][NFC] remove globals with mlir::StateStack (PR #144898)

Tom Eccles llvmlistbot at llvm.org
Tue Jun 24 09:46:31 PDT 2025


https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/144898

>From 57c6d1498234d2fa363977c4b772a7258ce6d4c0 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 18 Jun 2025 21:01:13 +0000
Subject: [PATCH 1/2] [flang][OpenMP][NFC] remove globals with mlir::StateStack

Idea suggested by @skatrak
---
 flang/include/flang/Lower/AbstractConverter.h |   3 +
 flang/lib/Lower/Bridge.cpp                    |   6 ++
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 102 ++++++++++++------
 mlir/include/mlir/Support/StateStack.h        |  11 ++
 4 files changed, 91 insertions(+), 31 deletions(-)

diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 8ae68e143cd2f..de3e833f60699 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -26,6 +26,7 @@
 
 namespace mlir {
 class SymbolTable;
+class StateStack;
 }
 
 namespace fir {
@@ -361,6 +362,8 @@ class AbstractConverter {
   /// functions in order to be in sync).
   virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
 
+  virtual mlir::StateStack &getStateStack() = 0;
+
 private:
   /// Options controlling lowering behavior.
   const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 64b16b3abe991..8506b9a984e58 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -69,6 +69,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Parser/Parser.h"
+#include "mlir/Support/StateStack.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
@@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
 
+  mlir::StateStack &getStateStack() override { return stateStack; }
+
   /// Add the symbol to the local map and return `true`. If the symbol is
   /// already in the map and \p forced is `false`, the map is not updated.
   /// Instead the value `false` is returned.
@@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// attribute since mlirSymbolTable must pro-actively be maintained when
   /// new Symbol operations are created.
   mlir::SymbolTable mlirSymbolTable;
+
+  /// Used to store context while recursing into regions during lowering.
+  mlir::StateStack stateStack;
 };
 
 } // namespace
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ebd1d038716e4..7f5ece71e692c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -39,6 +39,7 @@
 #include "flang/Support/OpenMP-utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Support/StateStack.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -198,9 +199,41 @@ class HostEvalInfo {
 /// the handling of the outer region by keeping a stack of information
 /// structures, but it will probably still require some further work to support
 /// reverse offloading.
-static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
-static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0>
-    sectionsStack;
+class HostEvalInfoStackFrame
+    : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame)
+
+  HostEvalInfo info;
+};
+
+static HostEvalInfo *
+getHostEvalInfoStackTop(lower::AbstractConverter &converter) {
+  HostEvalInfoStackFrame *frame =
+      converter.getStateStack().getStackTop<HostEvalInfoStackFrame>();
+  return frame ? &frame->info : nullptr;
+}
+
+/// Stack frame for storing the OpenMPSectionsConstruct currently being
+/// processed so that it can be refered to when lowering the construct.
+class SectionsConstructStackFrame
+    : public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame)
+
+  explicit SectionsConstructStackFrame(
+      const parser::OpenMPSectionsConstruct &sectionsConstruct)
+      : sectionsConstruct{sectionsConstruct} {}
+
+  const parser::OpenMPSectionsConstruct &sectionsConstruct;
+};
+
+static const parser::OpenMPSectionsConstruct *
+getSectionsConstructStackTop(lower::AbstractConverter &converter) {
+  SectionsConstructStackFrame *frame =
+      converter.getStateStack().getStackTop<SectionsConstructStackFrame>();
+  return frame ? &frame->sectionsConstruct : nullptr;
+}
 
 /// Bind symbols to their corresponding entry block arguments.
 ///
@@ -535,31 +568,32 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
     if (!ompEval)
       return;
 
-    HostEvalInfo &hostInfo = hostEvalInfo.back();
+    HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
+    assert(hostInfo && "expected HOST_EVAL info structure");
 
     switch (extractOmpDirective(*ompEval)) {
     case OMPD_teams_distribute_parallel_do:
     case OMPD_teams_distribute_parallel_do_simd:
-      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      cp.processThreadLimit(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_target_teams_distribute_parallel_do:
     case OMPD_target_teams_distribute_parallel_do_simd:
-      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      cp.processNumTeams(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_distribute_parallel_do:
     case OMPD_distribute_parallel_do_simd:
-      cp.processNumThreads(stmtCtx, hostInfo.ops);
+      cp.processNumThreads(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_distribute:
     case OMPD_distribute_simd:
-      cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
       break;
 
     case OMPD_teams:
-      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      cp.processThreadLimit(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_target_teams:
-      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      cp.processNumTeams(stmtCtx, hostInfo->ops);
       processSingleNestedIf([](Directive nestedDir) {
         return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
       });
@@ -567,22 +601,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
 
     case OMPD_teams_distribute:
     case OMPD_teams_distribute_simd:
-      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      cp.processThreadLimit(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_target_teams_distribute:
     case OMPD_target_teams_distribute_simd:
-      cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
-      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
+      cp.processNumTeams(stmtCtx, hostInfo->ops);
       break;
 
     case OMPD_teams_loop:
-      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      cp.processThreadLimit(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_target_teams_loop:
-      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      cp.processNumTeams(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_loop:
-      cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
       break;
 
     // Standalone 'target' case.
@@ -596,8 +630,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
     }
   };
 
-  assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
-
   const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
   assert(ompEval &&
          llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
@@ -1456,8 +1488,8 @@ static void genBodyOfTargetOp(
   mlir::Region &region = targetOp.getRegion();
   mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region);
   bindEntryBlockArgs(converter, targetOp, args);
-  if (!hostEvalInfo.empty())
-    hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
+  if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter))
+    hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs());
 
   // Check if cloning the bounds introduced any dependency on the outer region.
   // If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1696,7 +1728,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
                    llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
   ClauseProcessor cp(converter, semaCtx, clauses);
 
-  if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv))
+  HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+  if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
     cp.processCollapse(loc, eval, clauseOps, iv);
 
   clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
@@ -1741,7 +1774,8 @@ static void genParallelClauses(
   cp.processAllocate(clauseOps);
   cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
 
-  if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps))
+  HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+  if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps))
     cp.processNumThreads(stmtCtx, clauseOps);
 
   cp.processProcBind(clauseOps);
@@ -1806,16 +1840,17 @@ static void genTargetClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
     llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
     llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
+  HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processBare(clauseOps);
   cp.processDefaultMap(stmtCtx, defaultMaps);
   cp.processDepend(symTable, stmtCtx, clauseOps);
   cp.processDevice(stmtCtx, clauseOps);
   cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
-  if (!hostEvalInfo.empty()) {
+  if (hostEvalInfo) {
     // Only process host_eval if compiling for the host device.
     processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
-    hostEvalInfo.back().collectValues(clauseOps.hostEvalVars);
+    hostEvalInfo->collectValues(clauseOps.hostEvalVars);
   }
   cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
   cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
@@ -1952,7 +1987,8 @@ static void genTeamsClauses(
   cp.processAllocate(clauseOps);
   cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
 
-  if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) {
+  HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+  if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) {
     cp.processNumTeams(stmtCtx, clauseOps);
     cp.processThreadLimit(stmtCtx, clauseOps);
   }
@@ -2213,10 +2249,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               lower::pft::Evaluation &eval, mlir::Location loc,
               const ConstructQueue &queue,
               ConstructQueue::const_iterator item) {
-  assert(!sectionsStack.empty());
+  const parser::OpenMPSectionsConstruct *sectionsConstruct =
+      getSectionsConstructStackTop(converter);
+  assert(sectionsConstruct);
+
   const auto &sectionBlocks =
-      std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t);
-  sectionsStack.pop_back();
+      std::get<parser::OmpSectionBlocks>(sectionsConstruct->t);
+  converter.getStateStack().stackPop();
   mlir::omp::SectionsOperands clauseOps;
   llvm::SmallVector<const semantics::Symbol *> reductionSyms;
   genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps,
@@ -2370,7 +2409,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
   // Introduce a new host_eval information structure for this target region.
   if (!isTargetDevice)
-    hostEvalInfo.emplace_back();
+    converter.getStateStack().stackPush<HostEvalInfoStackFrame>();
 
   mlir::omp::TargetOperands clauseOps;
   DefaultMapsTy defaultMaps;
@@ -2497,7 +2536,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
   // Remove the host_eval information structure created for this target region.
   if (!isTargetDevice)
-    hostEvalInfo.pop_back();
+    converter.getStateStack().stackPop();
   return targetOp;
 }
 
@@ -3771,7 +3810,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
       buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
                           eval, source, directive, clauses)};
 
-  sectionsStack.push_back(&sectionsConstruct);
+  converter.getStateStack().stackPush<SectionsConstructStackFrame>(
+      sectionsConstruct);
   genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
                  queue.begin());
 }
diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h
index ac70d05a3020a..44972fafe7fed 100644
--- a/mlir/include/mlir/Support/StateStack.h
+++ b/mlir/include/mlir/Support/StateStack.h
@@ -84,6 +84,17 @@ class StateStack {
     return WalkResult::advance();
   }
 
+  /// Get the top instance of frame type `T` or nullptr if none are found
+  template <typename T>
+  T *getStackTop() {
+    T *top = nullptr;
+    stackWalk<T>([&](T &frame) -> mlir::WalkResult {
+      top = &frame;
+      return mlir::WalkResult::interrupt();
+    });
+    return top;
+  }
+
 private:
   SmallVector<std::unique_ptr<StateStackFrame>> stack;
 };

>From 89f1ee4e313e7deb0ce283a48b44d35fbb5bc27f Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Fri, 20 Jun 2025 11:25:58 +0000
Subject: [PATCH 2/2] Review comments

---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7f5ece71e692c..60b6366c184d4 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -215,7 +215,7 @@ getHostEvalInfoStackTop(lower::AbstractConverter &converter) {
 }
 
 /// Stack frame for storing the OpenMPSectionsConstruct currently being
-/// processed so that it can be refered to when lowering the construct.
+/// processed so that it can be referred to when lowering the construct.
 class SectionsConstructStackFrame
     : public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
 public:
@@ -1840,14 +1840,13 @@ static void genTargetClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
     llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
     llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
-  HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processBare(clauseOps);
   cp.processDefaultMap(stmtCtx, defaultMaps);
   cp.processDepend(symTable, stmtCtx, clauseOps);
   cp.processDevice(stmtCtx, clauseOps);
   cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
-  if (hostEvalInfo) {
+  if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter)) {
     // Only process host_eval if compiling for the host device.
     processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
     hostEvalInfo->collectValues(clauseOps.hostEvalVars);
@@ -2240,9 +2239,6 @@ genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       converter.getCurrentLocation(), clauseOps);
 }
 
-/// This breaks the normal prototype of the gen*Op functions: adding the
-/// sectionBlocks argument so that the enclosed section constructs can be
-/// lowered here with correct reduction symbol remapping.
 static mlir::omp::SectionsOp
 genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               semantics::SemanticsContext &semaCtx,
@@ -2251,11 +2247,10 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               ConstructQueue::const_iterator item) {
   const parser::OpenMPSectionsConstruct *sectionsConstruct =
       getSectionsConstructStackTop(converter);
-  assert(sectionsConstruct);
+  assert(sectionsConstruct && "Missing additional parsing information");
 
   const auto &sectionBlocks =
       std::get<parser::OmpSectionBlocks>(sectionsConstruct->t);
-  converter.getStateStack().stackPop();
   mlir::omp::SectionsOperands clauseOps;
   llvm::SmallVector<const semantics::Symbol *> reductionSyms;
   genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps,
@@ -3810,8 +3805,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
       buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
                           eval, source, directive, clauses)};
 
-  converter.getStateStack().stackPush<SectionsConstructStackFrame>(
-      sectionsConstruct);
+  mlir::SaveStateStack<SectionsConstructStackFrame> saveStateStack{
+      converter.getStateStack(), sectionsConstruct};
   genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
                  queue.begin());
 }



More information about the Mlir-commits mailing list