[llvm-branch-commits] [flang] [mlir] [flang][OpenMP][NFC] remove globals with mlir::StateStack (PR #144898)
Tom Eccles via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jun 19 07:05:12 PDT 2025
https://github.com/tblah created https://github.com/llvm/llvm-project/pull/144898
Idea suggested by @skatrak
>From 280e55d4355f100b7d3066fce3c0515b369fecce Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 18 Jun 2025 21:01:13 +0000
Subject: [PATCH] [flang][OpenMP][NFC] remove globals with mlir::StateStack
Idea suggested by @skatrak
---
flang/include/flang/Lower/AbstractConverter.h | 3 +
flang/lib/Lower/Bridge.cpp | 6 ++
flang/lib/Lower/OpenMP/OpenMP.cpp | 102 ++++++++++++------
mlir/include/mlir/Support/StateStack.h | 11 ++
4 files changed, 91 insertions(+), 31 deletions(-)
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 8ae68e143cd2f..de3e833f60699 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -26,6 +26,7 @@
namespace mlir {
class SymbolTable;
+class StateStack;
}
namespace fir {
@@ -361,6 +362,8 @@ class AbstractConverter {
/// functions in order to be in sync).
virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
+ virtual mlir::StateStack &getStateStack() = 0;
+
private:
/// Options controlling lowering behavior.
const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 64b16b3abe991..462ceb8dff736 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -78,6 +78,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include "llvm/Target/TargetMachine.h"
+#include "mlir/Support/StateStack.h"
#include <optional>
#define DEBUG_TYPE "flang-lower-bridge"
@@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
+ mlir::StateStack &getStateStack() override { return stateStack; }
+
/// Add the symbol to the local map and return `true`. If the symbol is
/// already in the map and \p forced is `false`, the map is not updated.
/// Instead the value `false` is returned.
@@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// attribute since mlirSymbolTable must pro-actively be maintained when
/// new Symbol operations are created.
mlir::SymbolTable mlirSymbolTable;
+
+ /// Used to store context while recursing into regions during lowering.
+ mlir::StateStack stateStack;
};
} // namespace
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7ad8869597274..bff3321af2814 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -38,6 +38,7 @@
#include "flang/Support/OpenMP-utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Support/StateStack.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -200,9 +201,41 @@ class HostEvalInfo {
/// the handling of the outer region by keeping a stack of information
/// structures, but it will probably still require some further work to support
/// reverse offloading.
-static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
-static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0>
- sectionsStack;
+class HostEvalInfoStackFrame
+ : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame)
+
+ HostEvalInfo info;
+};
+
+static HostEvalInfo *
+getHostEvalInfoStackTop(lower::AbstractConverter &converter) {
+ HostEvalInfoStackFrame *frame =
+ converter.getStateStack().getStackTop<HostEvalInfoStackFrame>();
+ return frame ? &frame->info : nullptr;
+}
+
+/// Stack frame for storing the OpenMPSectionsConstruct currently being
+/// processed so that it can be referred to when lowering the construct.
+class SectionsConstructStackFrame
+ : public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame)
+
+ explicit SectionsConstructStackFrame(
+ const parser::OpenMPSectionsConstruct &sectionsConstruct)
+ : sectionsConstruct{sectionsConstruct} {}
+
+ const parser::OpenMPSectionsConstruct &sectionsConstruct;
+};
+
+static const parser::OpenMPSectionsConstruct *
+getSectionsConstructStackTop(lower::AbstractConverter &converter) {
+ SectionsConstructStackFrame *frame =
+ converter.getStateStack().getStackTop<SectionsConstructStackFrame>();
+ return frame ? &frame->sectionsConstruct : nullptr;
+}
/// Bind symbols to their corresponding entry block arguments.
///
@@ -537,31 +570,32 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
if (!ompEval)
return;
- HostEvalInfo &hostInfo = hostEvalInfo.back();
+ HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
+ assert(hostInfo && "expected HOST_EVAL info structure");
switch (extractOmpDirective(*ompEval)) {
case OMPD_teams_distribute_parallel_do:
case OMPD_teams_distribute_parallel_do_simd:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_distribute_parallel_do:
case OMPD_target_teams_distribute_parallel_do_simd:
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_distribute_parallel_do:
case OMPD_distribute_parallel_do_simd:
- cp.processNumThreads(stmtCtx, hostInfo.ops);
+ cp.processNumThreads(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_distribute:
case OMPD_distribute_simd:
- cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+ cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
break;
case OMPD_teams:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams:
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
processSingleNestedIf([](Directive nestedDir) {
return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
});
@@ -569,22 +603,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
case OMPD_teams_distribute:
case OMPD_teams_distribute_simd:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_distribute:
case OMPD_target_teams_distribute_simd:
- cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
break;
case OMPD_teams_loop:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_loop:
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_loop:
- cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+ cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
break;
// Standalone 'target' case.
@@ -598,8 +632,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
}
};
- assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
-
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
assert(ompEval &&
llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
@@ -1468,8 +1500,8 @@ static void genBodyOfTargetOp(
mlir::Region &region = targetOp.getRegion();
mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region);
bindEntryBlockArgs(converter, targetOp, args);
- if (!hostEvalInfo.empty())
- hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
+ if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter))
+ hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs());
// Check if cloning the bounds introduced any dependency on the outer region.
// If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1708,7 +1740,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
ClauseProcessor cp(converter, semaCtx, clauses);
- if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv))
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+ if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
cp.processCollapse(loc, eval, clauseOps, iv);
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
@@ -1753,7 +1786,8 @@ static void genParallelClauses(
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
- if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps))
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+ if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps))
cp.processNumThreads(stmtCtx, clauseOps);
cp.processProcBind(clauseOps);
@@ -1818,16 +1852,17 @@ static void genTargetClauses(
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processBare(clauseOps);
cp.processDefaultMap(stmtCtx, defaultMaps);
cp.processDepend(symTable, stmtCtx, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
- if (!hostEvalInfo.empty()) {
+ if (hostEvalInfo) {
// Only process host_eval if compiling for the host device.
processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
- hostEvalInfo.back().collectValues(clauseOps.hostEvalVars);
+ hostEvalInfo->collectValues(clauseOps.hostEvalVars);
}
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
@@ -1963,7 +1998,8 @@ static void genTeamsClauses(
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
- if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) {
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+ if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) {
cp.processNumTeams(stmtCtx, clauseOps);
cp.processThreadLimit(stmtCtx, clauseOps);
}
@@ -2224,10 +2260,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
lower::pft::Evaluation &eval, mlir::Location loc,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
- assert(!sectionsStack.empty());
+ const parser::OpenMPSectionsConstruct *sectionsConstruct =
+ getSectionsConstructStackTop(converter);
+ assert(sectionsConstruct);
+
const auto &sectionBlocks =
- std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t);
- sectionsStack.pop_back();
+ std::get<parser::OmpSectionBlocks>(sectionsConstruct->t);
+ converter.getStateStack().stackPop();
mlir::omp::SectionsOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps,
@@ -2381,7 +2420,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// Introduce a new host_eval information structure for this target region.
if (!isTargetDevice)
- hostEvalInfo.emplace_back();
+ converter.getStateStack().stackPush<HostEvalInfoStackFrame>();
mlir::omp::TargetOperands clauseOps;
DefaultMapsTy defaultMaps;
@@ -2508,7 +2547,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// Remove the host_eval information structure created for this target region.
if (!isTargetDevice)
- hostEvalInfo.pop_back();
+ converter.getStateStack().stackPop();
return targetOp;
}
@@ -4235,7 +4274,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
eval, source, directive, clauses)};
- sectionsStack.push_back(&sectionsConstruct);
+ converter.getStateStack().stackPush<SectionsConstructStackFrame>(
+ sectionsConstruct);
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
queue.begin());
}
diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h
index aca2375028246..9641a22c47776 100644
--- a/mlir/include/mlir/Support/StateStack.h
+++ b/mlir/include/mlir/Support/StateStack.h
@@ -83,6 +83,17 @@ class StateStack {
return WalkResult::advance();
}
+ /// Get the top instance of frame type `T`, or nullptr if none are found.
+ template <typename T>
+ T *getStackTop() {
+ T *top = nullptr;
+ stackWalk<T>([&](T &frame) -> mlir::WalkResult {
+ top = &frame;
+ return mlir::WalkResult::interrupt();
+ });
+ return top;
+ }
+
private:
SmallVector<std::unique_ptr<StateStackFrame>> stack;
};
More information about the llvm-branch-commits
mailing list