[llvm-branch-commits] [flang] [mlir] [Flang][OpenMP] Lowering of host-evaluated clauses (PR #116219)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Nov 14 04:45:17 PST 2024
https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/116219
This patch adds support for lowering OpenMP clauses and expressions that are attached to constructs nested inside of a target region but need to be evaluated on the host device. This is done through the set of `omp.target` operands and entry block arguments introduced by `OpenMP_HostEvalClause`.
When lowering clauses for a target construct, a more involved `processHostEvalClauses()` function is called. It inspects the current construct and, where applicable, its nested constructs in order to find and lower clauses that need to be processed outside of the `omp.target` operation under construction, and stores the resulting MLIR values in a global structure.
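To illustrate the kind of nested-construct walk described above, here is a deliberately simplified sketch. `Construct`, `Directive`, `HostEvalValues` and `processConstruct` are hypothetical stand-ins (the patch itself walks `lower::pft::Evaluation` nodes and uses `ClauseProcessor`), clause values are plain integers, and loop bounds and `collapse` are omitted. It mirrors the rule that a nested construct is only inspected when it is the single statement inside its parent.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical, simplified model of a parsed OpenMP construct. Only the
// clauses relevant to host evaluation are kept, as plain integers.
enum class Directive { Target, Teams, TargetTeams, DistributeParallelDo, Other };

struct Construct {
  Directive dir = Directive::Other;
  std::optional<int> numTeams;    // num_teams clause, if present
  std::optional<int> threadLimit; // thread_limit clause, if present
  std::optional<int> numThreads;  // num_threads clause, if present
  std::vector<Construct> nested;  // nested statements, end statements excluded
};

// Values that would have to be computed on the host before the target region.
struct HostEvalValues {
  std::optional<int> numTeams, threadLimit, numThreads;
};

// Return the only nested construct of `parent`, or nullptr if there is not
// exactly one (e.g. an extra call statement rules out the SPMD pattern).
static const Construct *onlyNested(const Construct &parent) {
  return parent.nested.size() == 1 ? &parent.nested.front() : nullptr;
}

// Walk from a 'target' construct into its nested constructs, recording the
// clauses that must be evaluated on the host.
static void processConstruct(const Construct &c, HostEvalValues &out) {
  const Construct *n = onlyNested(c);
  switch (c.dir) {
  case Directive::Target:
    // A standalone 'target' only matters through a single nested 'teams'.
    if (n && n->dir == Directive::Teams)
      processConstruct(*n, out);
    break;
  case Directive::Teams:
    // Mirrors the patch: thread_limit is collected for a nested 'teams' only.
    out.threadLimit = c.threadLimit;
    [[fallthrough]];
  case Directive::TargetTeams:
    out.numTeams = c.numTeams;
    // num_threads is host-evaluated only for the SPMD nesting.
    if (n && n->dir == Directive::DistributeParallelDo)
      processConstruct(*n, out);
    break;
  case Directive::DistributeParallelDo:
    out.numThreads = c.numThreads;
    break;
  case Directive::Other:
    break;
  }
}
```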
The resulting list of host-evaluated values is used to initialize the `host_eval` operands when constructing the `omp.target` operation. The stored values are then replaced with the corresponding block arguments once that operation's region has been created.
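The sketch below shows why the ordering matters: the flattened `host_eval` list and the later rebinding to block arguments must visit the operands in exactly the same order, as `collectValues()` and `bindOperands()` do in the patch. It assumes simplified, hypothetical types (`Value` as an integer id, `HostEvalOps` with a single `numTeams` instead of separate lower and upper bounds) and free functions rather than the patch's member functions.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

using Value = int; // hypothetical stand-in for an mlir::Value (an SSA id)

// Simplified stand-in for the host-evaluated operands of nested constructs.
struct HostEvalOps {
  std::vector<Value> loopLowerBounds, loopUpperBounds, loopSteps;
  std::optional<Value> numTeams, numThreads, threadLimit;
};

// Flatten the operands in a fixed order to build the host_eval operand list.
std::vector<Value> collectValues(const HostEvalOps &ops) {
  std::vector<Value> vars;
  for (const auto *list :
       {&ops.loopLowerBounds, &ops.loopUpperBounds, &ops.loopSteps})
    vars.insert(vars.end(), list->begin(), list->end());
  for (const auto *opt : {&ops.numTeams, &ops.numThreads, &ops.threadLimit})
    if (*opt)
      vars.push_back(**opt);
  return vars;
}

// Overwrite each stored value with the block argument at the same position,
// visiting operands in exactly the order used by collectValues().
void bindOperands(HostEvalOps &ops, const std::vector<Value> &args) {
  std::size_t i = 0;
  for (auto *list :
       {&ops.loopLowerBounds, &ops.loopUpperBounds, &ops.loopSteps})
    for (Value &v : *list)
      v = args.at(i++);
  for (auto *opt : {&ops.numTeams, &ops.numThreads, &ops.threadLimit})
    if (*opt)
      **opt = args.at(i++);
  assert(i == args.size() && "invalid block argument list");
}
```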
Afterwards, when lowering nested operations, the code handling clauses that might be evaluated on the host (e.g. `num_teams`, `thread_limit`, `num_threads` and `collapse`) first checks whether there is an active global host-evaluated information structure and whether it holds values for these clauses. If so, the stored values (which refer to `omp.target` entry block arguments at that stage) are used instead of lowering the clauses again.
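A minimal sketch of this check-then-fallback pattern for `num_threads`, assuming hypothetical names (`HostEvalEntry`, `hostEvalStack`, `lowerNumThreads`) and integers in place of MLIR values. As with the `apply()` overloads in the patch, the host-evaluated value is handed out at most once, so only the first nested `parallel` reuses the target's block argument.

```cpp
#include <cassert>
#include <optional>
#include <vector>

using Value = int; // hypothetical stand-in for an mlir::Value

// One entry per omp.target being lowered; the innermost one is at the back.
struct HostEvalEntry {
  std::optional<Value> numThreads; // target block argument, once bound
  bool parallelApplied = false;

  // Hand the host-evaluated num_threads to the first nested 'parallel' only.
  // Returns false if there is nothing to apply or it was already consumed.
  bool apply(std::optional<Value> &clauseNumThreads) {
    bool applies = numThreads.has_value() && !parallelApplied;
    parallelApplied = true;
    if (applies)
      clauseNumThreads = numThreads;
    return applies;
  }
};

static std::vector<HostEvalEntry> hostEvalStack;

// Lower a 'parallel' num_threads clause: reuse the host-evaluated value when
// available, otherwise fall back to lowering the local clause in place.
std::optional<Value> lowerNumThreads(std::optional<Value> localClause) {
  std::optional<Value> result;
  if (hostEvalStack.empty() || !hostEvalStack.back().apply(result))
    result = localClause;
  return result;
}
```

Keeping a stack rather than a single global entry lets a nested `target` region push and pop its own entry without clobbering the values of the outer region, which is the rationale the patch gives for its `hostEvalInfo` stack.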
From 6020805f93413e03e7e18aa167e29fea3e797c57 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 14 Nov 2024 12:24:15 +0000
Subject: [PATCH] [Flang][OpenMP] Lowering of host-evaluated clauses
This patch adds support for lowering OpenMP clauses and expressions that are
attached to constructs nested inside of a target region but need to be
evaluated on the host device. This is done through the set of `omp.target`
operands and entry block arguments introduced by `OpenMP_HostEvalClause`.
When lowering clauses for a target construct, a more involved
`processHostEvalClauses()` function is called, which looks at the current and
potentially other nested constructs in order to find and lower clauses that
need to be processed outside of the `omp.target` operation under construction.
This populates an instance of a global structure with the resulting MLIR
values.
The resulting list of host-evaluated values is used to initialize the
`host_eval` operands when constructing the `omp.target` operation. The stored
values are then replaced with the corresponding block arguments once that
operation's region has been created.
Afterwards, when lowering nested operations, the code handling clauses that
might be evaluated on the host (e.g. `num_teams`, `thread_limit`,
`num_threads` and `collapse`) first checks whether there is an active global
host-evaluated information structure and whether it holds values for these
clauses. If so, the stored values (which refer to `omp.target` entry block
arguments at that stage) are used instead of lowering the clauses again.
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 458 ++++++++++++++++--
flang/test/Lower/OpenMP/host-eval.f90 | 138 ++++++
flang/test/Lower/OpenMP/target-spmd.f90 | 191 ++++++++
.../Dialect/OpenMP/OpenMPClauseOperands.h | 6 +
4 files changed, 764 insertions(+), 29 deletions(-)
create mode 100644 flang/test/Lower/OpenMP/host-eval.f90
create mode 100644 flang/test/Lower/OpenMP/target-spmd.f90
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 91f99ba4b0ca55..a206af77a2f51f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -45,6 +45,19 @@ using namespace Fortran::lower::omp;
// Code generation helper functions
//===----------------------------------------------------------------------===//
+static void genOMPDispatch(lower::AbstractConverter &converter,
+ lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx,
+ lower::pft::Evaluation &eval, mlir::Location loc,
+ const ConstructQueue &queue,
+ ConstructQueue::const_iterator item);
+
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ lower::StatementContext &stmtCtx,
+ lower::pft::Evaluation &eval,
+ mlir::Location loc);
+
namespace {
/// Structure holding the information needed to create and bind entry block
/// arguments associated to a single clause.
@@ -63,6 +76,7 @@ struct EntryBlockArgsEntry {
/// Structure holding the information needed to create and bind entry block
/// arguments associated to all clauses that can define them.
struct EntryBlockArgs {
+ llvm::ArrayRef<mlir::Value> hostEvalVars;
EntryBlockArgsEntry inReduction;
EntryBlockArgsEntry map;
EntryBlockArgsEntry priv;
@@ -85,18 +99,146 @@ struct EntryBlockArgs {
auto getVars() const {
return llvm::concat<const mlir::Value>(
- inReduction.vars, map.vars, priv.vars, reduction.vars,
+ hostEvalVars, inReduction.vars, map.vars, priv.vars, reduction.vars,
taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
}
};
+
+/// Structure holding information that is needed to pass host-evaluated
+/// information to later lowering stages.
+class HostEvalInfo {
+public:
+ // Allow this function access to private members in order to initialize them.
+ friend void ::processHostEvalClauses(lower::AbstractConverter &,
+ semantics::SemanticsContext &,
+ lower::StatementContext &,
+ lower::pft::Evaluation &,
+ mlir::Location);
+
+ /// Fill \c vars with values stored in \c ops.
+ ///
+ /// The order in which values are stored matches the one expected by \see
+ /// bindOperands().
+ void collectValues(llvm::SmallVectorImpl<mlir::Value> &vars) const {
+ vars.append(ops.loopLowerBounds);
+ vars.append(ops.loopUpperBounds);
+ vars.append(ops.loopSteps);
+
+ if (ops.numTeamsLower)
+ vars.push_back(ops.numTeamsLower);
+
+ if (ops.numTeamsUpper)
+ vars.push_back(ops.numTeamsUpper);
+
+ if (ops.numThreads)
+ vars.push_back(ops.numThreads);
+
+ if (ops.threadLimit)
+ vars.push_back(ops.threadLimit);
+ }
+
+ /// Update \c ops, replacing all values with the corresponding block argument
+ /// in \c args.
+ ///
+ /// The order in which values are stored in \c args is the same as the one
+ /// used by \see collectValues().
+ void bindOperands(llvm::ArrayRef<mlir::BlockArgument> args) {
+ assert(args.size() ==
+ ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
+ ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
+ (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+ (ops.threadLimit ? 1 : 0) &&
+ "invalid block argument list");
+ int argIndex = 0;
+ for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
+ ops.loopLowerBounds[i] = args[argIndex++];
+
+ for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i)
+ ops.loopUpperBounds[i] = args[argIndex++];
+
+ for (size_t i = 0; i < ops.loopSteps.size(); ++i)
+ ops.loopSteps[i] = args[argIndex++];
+
+ if (ops.numTeamsLower)
+ ops.numTeamsLower = args[argIndex++];
+
+ if (ops.numTeamsUpper)
+ ops.numTeamsUpper = args[argIndex++];
+
+ if (ops.numThreads)
+ ops.numThreads = args[argIndex++];
+
+ if (ops.threadLimit)
+ ops.threadLimit = args[argIndex++];
+ }
+
+ /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
+ /// values and Fortran symbols, respectively, if they have already been
+ /// initialized but not yet applied.
+ ///
+ /// \returns whether an update was performed. If not, these clauses were not
+ /// evaluated in the host device.
+ bool apply(mlir::omp::LoopNestOperands &clauseOps,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &ivOut) {
+ if (iv.empty() || loopNestApplied) {
+ loopNestApplied = true;
+ return false;
+ }
+
+ loopNestApplied = true;
+ clauseOps.loopLowerBounds = ops.loopLowerBounds;
+ clauseOps.loopUpperBounds = ops.loopUpperBounds;
+ clauseOps.loopSteps = ops.loopSteps;
+ ivOut.append(iv);
+ return true;
+ }
+
+ /// Update \p clauseOps with the corresponding host-evaluated values if they
+ /// have already been initialized but not yet applied.
+ ///
+ /// \returns whether an update was performed. If not, these clauses were not
+ /// evaluated in the host device.
+ bool apply(mlir::omp::ParallelOperands &clauseOps) {
+ if (!ops.numThreads || parallelApplied) {
+ parallelApplied = true;
+ return false;
+ }
+
+ parallelApplied = true;
+ clauseOps.numThreads = ops.numThreads;
+ return true;
+ }
+
+ /// Update \p clauseOps with the corresponding host-evaluated values if they
+ /// have already been initialized.
+ ///
+ /// \returns whether an update was performed. If not, these clauses were not
+ /// evaluated in the host device.
+ bool apply(mlir::omp::TeamsOperands &clauseOps) {
+ if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit)
+ return false;
+
+ clauseOps.numTeamsLower = ops.numTeamsLower;
+ clauseOps.numTeamsUpper = ops.numTeamsUpper;
+ clauseOps.threadLimit = ops.threadLimit;
+ return true;
+ }
+
+private:
+ mlir::omp::HostEvaluatedOperands ops;
+ llvm::SmallVector<const semantics::Symbol *> iv;
+ bool loopNestApplied = false, parallelApplied = false;
+};
} // namespace
-static void genOMPDispatch(lower::AbstractConverter &converter,
- lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx,
- lower::pft::Evaluation &eval, mlir::Location loc,
- const ConstructQueue &queue,
- ConstructQueue::const_iterator item);
+/// Stack of \see HostEvalInfo to represent the current nest of \c omp.target
+/// operations being created.
+///
+/// The current implementation prevents nested 'target' regions from breaking
+/// the handling of the outer region by keeping a stack of information
+/// structures, but it will probably still require some further work to support
+/// reverse offloading.
+static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
/// Bind symbols to their corresponding entry block arguments.
///
@@ -219,6 +361,8 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
};
// Process in clause name alphabetical order to match block arguments order.
+ // Do not bind host_eval variables because they cannot be used inside of the
+ // corresponding region, except for very specific cases handled separately.
bindPrivateLike(args.inReduction.syms, args.inReduction.vars,
op.getInReductionBlockArgs());
bindMapLike(args.map.syms, op.getMapBlockArgs());
@@ -256,6 +400,246 @@ extractMappedBaseValues(llvm::ArrayRef<mlir::Value> vars,
});
}
+/// Get the directive enumeration value corresponding to the given OpenMP
+/// construct PFT node.
+llvm::omp::Directive
+extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
+ return common::visit(
+ common::visitors{
+ [](const parser::OpenMPAllocatorsConstruct &c) {
+ return llvm::omp::OMPD_allocators;
+ },
+ [](const parser::OpenMPAtomicConstruct &c) {
+ return llvm::omp::OMPD_atomic;
+ },
+ [](const parser::OpenMPBlockConstruct &c) {
+ return std::get<parser::OmpBlockDirective>(
+ std::get<parser::OmpBeginBlockDirective>(c.t).t)
+ .v;
+ },
+ [](const parser::OpenMPCriticalConstruct &c) {
+ return llvm::omp::OMPD_critical;
+ },
+ [](const parser::OpenMPDeclarativeAllocate &c) {
+ return llvm::omp::OMPD_allocate;
+ },
+ [](const parser::OpenMPExecutableAllocate &c) {
+ return llvm::omp::OMPD_allocate;
+ },
+ [](const parser::OpenMPLoopConstruct &c) {
+ return std::get<parser::OmpLoopDirective>(
+ std::get<parser::OmpBeginLoopDirective>(c.t).t)
+ .v;
+ },
+ [](const parser::OpenMPSectionConstruct &c) {
+ return llvm::omp::OMPD_section;
+ },
+ [](const parser::OpenMPSectionsConstruct &c) {
+ return std::get<parser::OmpSectionsDirective>(
+ std::get<parser::OmpBeginSectionsDirective>(c.t).t)
+ .v;
+ },
+ [](const parser::OpenMPStandaloneConstruct &c) {
+ return common::visit(
+ common::visitors{
+ [](const parser::OpenMPSimpleStandaloneConstruct &c) {
+ return std::get<parser::OmpSimpleStandaloneDirective>(c.t)
+ .v;
+ },
+ [](const parser::OpenMPFlushConstruct &c) {
+ return llvm::omp::OMPD_flush;
+ },
+ [](const parser::OpenMPCancelConstruct &c) {
+ return llvm::omp::OMPD_cancel;
+ },
+ [](const parser::OpenMPCancellationPointConstruct &c) {
+ return llvm::omp::OMPD_cancellation_point;
+ },
+ [](const parser::OpenMPDepobjConstruct &c) {
+ return llvm::omp::OMPD_depobj;
+ }},
+ c.u);
+ }},
+ ompConstruct.u);
+}
+
+/// Populate the global \see hostEvalInfo after processing clauses for the given
+/// \p eval OpenMP target construct, or nested constructs, if these must be
+/// evaluated outside of the target region per the spec.
+///
+/// In particular, this will ensure that in 'target teams' and equivalent nested
+/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated
+/// in the host. Additionally, loop bounds, steps and the \c num_threads clause
+/// will also be evaluated in the host if a target SPMD construct is detected
+/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting).
+///
+/// The result, stored as a global, is intended to be used to populate the \c
+/// host_eval operands of the associated \c omp.target operation, and also to be
+/// checked and used by later lowering steps to populate the corresponding
+/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest
+/// operations.
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ lower::StatementContext &stmtCtx,
+ lower::pft::Evaluation &eval,
+ mlir::Location loc) {
+ // Obtain the list of clauses of the given OpenMP block or loop construct
+ // evaluation. Other evaluations passed to this lambda keep `clauses`
+ // unchanged.
+ auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval,
+ List<Clause> &clauses) {
+ const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+ if (!ompEval)
+ return;
+
+ const parser::OmpClauseList *beginClauseList = nullptr;
+ const parser::OmpClauseList *endClauseList = nullptr;
+ common::visit(
+ common::visitors{
+ [&](const parser::OpenMPBlockConstruct &ompConstruct) {
+ const auto &beginDirective =
+ std::get<parser::OmpBeginBlockDirective>(ompConstruct.t);
+ beginClauseList =
+ &std::get<parser::OmpClauseList>(beginDirective.t);
+ endClauseList = &std::get<parser::OmpClauseList>(
+ std::get<parser::OmpEndBlockDirective>(ompConstruct.t).t);
+ },
+ [&](const parser::OpenMPLoopConstruct &ompConstruct) {
+ const auto &beginDirective =
+ std::get<parser::OmpBeginLoopDirective>(ompConstruct.t);
+ beginClauseList =
+ &std::get<parser::OmpClauseList>(beginDirective.t);
+
+ if (auto &endDirective =
+ std::get<std::optional<parser::OmpEndLoopDirective>>(
+ ompConstruct.t))
+ endClauseList =
+ &std::get<parser::OmpClauseList>(endDirective->t);
+ },
+ [&](const auto &) {}},
+ ompEval->u);
+
+ assert(beginClauseList && "expected begin directive");
+ clauses.append(makeClauses(*beginClauseList, semaCtx));
+
+ if (endClauseList)
+ clauses.append(makeClauses(*endClauseList, semaCtx));
+ };
+
+ // Return the directive that is immediately nested inside of the given
+ // `parent` evaluation, if it is its only non-end-statement nested evaluation
+ // and it represents an OpenMP construct.
+ auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent)
+ -> std::optional<llvm::omp::Directive> {
+ if (!parent.hasNestedEvaluations())
+ return std::nullopt;
+
+ llvm::omp::Directive dir;
+ auto &nested = parent.getFirstNestedEvaluation();
+ if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
+ dir = extractOmpDirective(*ompEval);
+ else
+ return std::nullopt;
+
+ for (auto &sibling : parent.getNestedEvaluations())
+ if (&sibling != &nested && !sibling.isEndStmt())
+ return std::nullopt;
+
+ return dir;
+ };
+
+ // Process the given evaluation assuming it's part of a 'target' construct or
+ // captured by one, and store results in the global `hostEvalInfo`.
+ std::function<void(lower::pft::Evaluation &, const List<Clause> &)>
+ processEval;
+ processEval = [&](lower::pft::Evaluation &eval, const List<Clause> &clauses) {
+ using namespace llvm::omp;
+ ClauseProcessor cp(converter, semaCtx, clauses);
+
+ // Call `processEval` recursively with the immediately nested evaluation and
+ // its corresponding clauses if there is a single nested evaluation
+ // representing an OpenMP directive that passes the given test.
+ auto processSingleNestedIf = [&](llvm::function_ref<bool(Directive)> test) {
+ std::optional<Directive> nestedDir = extractOnlyOmpNestedDir(eval);
+ if (!nestedDir || !test(*nestedDir))
+ return;
+
+ lower::pft::Evaluation &nestedEval = eval.getFirstNestedEvaluation();
+ List<lower::omp::Clause> nestedClauses;
+ extractClauses(nestedEval, nestedClauses);
+ processEval(nestedEval, nestedClauses);
+ };
+
+ const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+ if (!ompEval)
+ return;
+
+ HostEvalInfo &hostInfo = hostEvalInfo.back();
+
+ switch (extractOmpDirective(*ompEval)) {
+ // Cases where 'teams' and target SPMD clauses might be present.
+ case OMPD_teams_distribute_parallel_do:
+ case OMPD_teams_distribute_parallel_do_simd:
+ cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams_distribute_parallel_do:
+ case OMPD_target_teams_distribute_parallel_do_simd:
+ cp.processNumTeams(stmtCtx, hostInfo.ops);
+ [[fallthrough]];
+ case OMPD_distribute_parallel_do:
+ case OMPD_distribute_parallel_do_simd:
+ cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+ cp.processNumThreads(stmtCtx, hostInfo.ops);
+ break;
+
+ // Cases where 'teams' clauses might be present, and target SPMD is
+ // possible by looking at nested evaluations.
+ case OMPD_teams:
+ cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams:
+ cp.processNumTeams(stmtCtx, hostInfo.ops);
+ processSingleNestedIf([](Directive nestedDir) {
+ return nestedDir == OMPD_distribute_parallel_do ||
+ nestedDir == OMPD_distribute_parallel_do_simd;
+ });
+ break;
+
+ // Cases where only 'teams' host-evaluated clauses might be present.
+ case OMPD_teams_distribute:
+ case OMPD_teams_distribute_simd:
+ cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams_distribute:
+ case OMPD_target_teams_distribute_simd:
+ cp.processNumTeams(stmtCtx, hostInfo.ops);
+ break;
+
+ // Standalone 'target' case.
+ case OMPD_target: {
+ processSingleNestedIf(
+ [](Directive nestedDir) { return topTeamsSet.test(nestedDir); });
+ break;
+ }
+ default:
+ break;
+ }
+ };
+
+ assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
+ const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+ assert(ompEval &&
+ llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
+ "expected TARGET construct evaluation");
+
+ // Use the whole list of clauses passed to the construct here, rather than the
+ // ones only applied to omp.target.
+ List<lower::omp::Clause> clauses;
+ extractClauses(eval, clauses);
+ processEval(eval, clauses);
+}
+
static lower::pft::Evaluation *
getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) {
// Return the Evaluation of the innermost collapsed loop, or the current one
@@ -638,11 +1022,11 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Type> types;
llvm::SmallVector<mlir::Location> locs;
- unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
- args.priv.vars.size() + args.reduction.vars.size() +
- args.taskReduction.vars.size() +
- args.useDeviceAddr.vars.size() +
- args.useDevicePtr.vars.size();
+ unsigned numVars =
+ args.hostEvalVars.size() + args.inReduction.vars.size() +
+ args.map.vars.size() + args.priv.vars.size() +
+ args.reduction.vars.size() + args.taskReduction.vars.size() +
+ args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size();
types.reserve(numVars);
locs.reserve(numVars);
@@ -655,6 +1039,7 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
// Populate block arguments in clause name alphabetical order to match
// expected order by the BlockArgOpenMPOpInterface.
+ extractTypeLoc(args.hostEvalVars);
extractTypeLoc(args.inReduction.vars);
extractTypeLoc(args.map.vars);
extractTypeLoc(args.priv.vars);
@@ -991,12 +1376,15 @@ static void genBodyOfTargetOp(
mlir::omp::TargetOp &targetOp, const EntryBlockArgs &args,
const mlir::Location &currentLocation, const ConstructQueue &queue,
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
+ assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);
mlir::Region &region = targetOp.getRegion();
mlir::Block *entryBlock = genEntryBlock(converter, args, region);
bindEntryBlockArgs(converter, targetOp, args);
+ hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
// Check if cloning the bounds introduced any dependency on the outer region.
// If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1172,7 +1560,10 @@ genLoopNestClauses(lower::AbstractConverter &converter,
mlir::Location loc, mlir::omp::LoopNestOperands &clauseOps,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
ClauseProcessor cp(converter, semaCtx, clauses);
- cp.processCollapse(loc, eval, clauseOps, iv);
+
+ if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv))
+ cp.processCollapse(loc, eval, clauseOps, iv);
+
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
}
@@ -1202,7 +1593,10 @@ static void genParallelClauses(
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
- cp.processNumThreads(stmtCtx, clauseOps);
+
+ if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps))
+ cp.processNumThreads(stmtCtx, clauseOps);
+
cp.processProcBind(clauseOps);
cp.processReduction(loc, clauseOps, reductionSyms);
}
@@ -1249,23 +1643,24 @@ static void genSingleClauses(lower::AbstractConverter &converter,
static void genTargetClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
- lower::StatementContext &stmtCtx, const List<Clause> &clauses,
- mlir::Location loc, bool processHostOnlyClauses,
+ lower::StatementContext &stmtCtx, lower::pft::Evaluation &eval,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TargetOperands &clauseOps,
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
+ assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processDepend(clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms);
+ processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
+ hostEvalInfo.back().collectValues(clauseOps.hostEvalVars);
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
cp.processMap(loc, stmtCtx, clauseOps, &mapSyms);
-
- if (processHostOnlyClauses)
- cp.processNowait(clauseOps);
-
+ cp.processNowait(clauseOps);
cp.processThreadLimit(stmtCtx, clauseOps);
cp.processTODO<clause::Allocate, clause::Defaultmap, clause::Firstprivate,
@@ -1367,10 +1762,13 @@ static void genTeamsClauses(lower::AbstractConverter &converter,
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
- cp.processNumTeams(stmtCtx, clauseOps);
- cp.processThreadLimit(stmtCtx, clauseOps);
- // TODO Support delayed privatization.
+ if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) {
+ cp.processNumTeams(stmtCtx, clauseOps);
+ cp.processThreadLimit(stmtCtx, clauseOps);
+ }
+
+ // TODO Support delayed privatization.
cp.processTODO<clause::Reduction>(loc, llvm::omp::Directive::OMPD_teams);
}
@@ -1710,16 +2108,14 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
lower::StatementContext stmtCtx;
- bool processHostOnlyClauses =
- !llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp())
- .getIsTargetDevice();
+ // Introduce a new host_eval information structure for this target region.
+ hostEvalInfo.emplace_back();
mlir::omp::TargetOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> mapSyms, isDevicePtrSyms,
hasDeviceAddrSyms;
- genTargetClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
- processHostOnlyClauses, clauseOps, hasDeviceAddrSyms,
- isDevicePtrSyms, mapSyms);
+ genTargetClauses(converter, semaCtx, stmtCtx, eval, item->clauses, loc,
+ clauseOps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms);
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/
@@ -1836,6 +2232,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
extractMappedBaseValues(clauseOps.mapVars, mapBaseValues);
EntryBlockArgs args;
+ args.hostEvalVars = clauseOps.hostEvalVars;
// TODO: Add in_reduction syms and vars.
args.map.syms = mapSyms;
args.map.vars = mapBaseValues;
@@ -1844,6 +2241,9 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, args, loc,
queue, item, dsp);
+
+ // Remove the host_eval information structure created for this target region.
+ hostEvalInfo.pop_back();
return targetOp;
}
diff --git a/flang/test/Lower/OpenMP/host-eval.f90 b/flang/test/Lower/OpenMP/host-eval.f90
new file mode 100644
index 00000000000000..e23389ee346db7
--- /dev/null
+++ b/flang/test/Lower/OpenMP/host-eval.f90
@@ -0,0 +1,138 @@
+! The "thread_limit" clause was added to the "target" construct in OpenMP 5.1.
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 -fopenmp-is-target-device %s -o - | FileCheck %s
+
+! CHECK-LABEL: func.func @_QPteams
+subroutine teams()
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval(%{{.*}} -> %[[NUM_TEAMS:.*]], %{{.*}} -> %[[THREAD_LIMIT:.*]] : i32, i32)
+ !$omp target
+
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams( to %[[NUM_TEAMS]] : i32) thread_limit(%[[THREAD_LIMIT]] : i32)
+ !$omp teams num_teams(1) thread_limit(2)
+ call foo()
+ !$omp end teams
+
+ !$omp end target
+
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams({{.*}}) thread_limit({{.*}}) {
+ !$omp teams num_teams(1) thread_limit(2)
+ call foo()
+ !$omp end teams
+end subroutine teams
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do
+subroutine distribute_parallel_do()
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32)
+
+ ! CHECK: omp.teams
+ !$omp target teams
+
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads(%[[NUM_THREADS]] : i32)
+
+ ! CHECK: omp.distribute
+ ! CHECK-NEXT: omp.wsloop
+ ! CHECK-NEXT: omp.loop_nest
+ ! CHECK-SAME: (%{{.*}}) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]])
+ !$omp distribute parallel do num_threads(1)
+ do i=1,10
+ call foo()
+ end do
+ !$omp end distribute parallel do
+ !$omp end target teams
+
+ ! CHECK: omp.target
+ ! CHECK-NOT: host_eval({{.*}})
+ ! CHECK-SAME: {
+ ! CHECK: omp.teams
+ !$omp target teams
+ call foo() !< Prevents this from being SPMD.
+
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads({{.*}})
+ ! CHECK: omp.distribute
+ ! CHECK-NEXT: omp.wsloop
+ !$omp distribute parallel do num_threads(1)
+ do i=1,10
+ call foo()
+ end do
+ !$omp end distribute parallel do
+ !$omp end target teams
+
+ ! CHECK: omp.teams
+ !$omp teams
+
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads({{.*}})
+ ! CHECK: omp.distribute
+ ! CHECK-NEXT: omp.wsloop
+ !$omp distribute parallel do num_threads(1)
+ do i=1,10
+ call foo()
+ end do
+ !$omp end distribute parallel do
+ !$omp end teams
+end subroutine distribute_parallel_do
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_simd
+subroutine distribute_parallel_do_simd()
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32)
+
+ ! CHECK: omp.teams
+ !$omp target teams
+
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads(%[[NUM_THREADS]] : i32)
+
+ ! CHECK: omp.distribute
+ ! CHECK-NEXT: omp.wsloop
+ ! CHECK-NEXT: omp.simd
+ ! CHECK-NEXT: omp.loop_nest
+
+ ! CHECK-SAME: (%{{.*}}) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]])
+ !$omp distribute parallel do simd num_threads(1)
+ do i=1,10
+ call foo()
+ end do
+ !$omp end distribute parallel do simd
+ !$omp end target teams
+
+ ! CHECK: omp.target
+ ! CHECK-NOT: host_eval({{.*}})
+ ! CHECK-SAME: {
+ ! CHECK: omp.teams
+ !$omp target teams
+ call foo() !< Prevents this from being SPMD.
+
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads({{.*}})
+ ! CHECK: omp.distribute
+ ! CHECK-NEXT: omp.wsloop
+ ! CHECK-NEXT: omp.simd
+ !$omp distribute parallel do simd num_threads(1)
+ do i=1,10
+ call foo()
+ end do
+ !$omp end distribute parallel do simd
+ !$omp end target teams
+
+ ! CHECK: omp.teams
+ !$omp teams
+
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads({{.*}})
+ ! CHECK: omp.distribute
+ ! CHECK-NEXT: omp.wsloop
+ ! CHECK-NEXT: omp.simd
+ !$omp distribute parallel do simd num_threads(1)
+ do i=1,10
+ call foo()
+ end do
+ !$omp end distribute parallel do simd
+ !$omp end teams
+end subroutine distribute_parallel_do_simd
diff --git a/flang/test/Lower/OpenMP/target-spmd.f90 b/flang/test/Lower/OpenMP/target-spmd.f90
new file mode 100644
index 00000000000000..43613819ccc8e9
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-spmd.f90
@@ -0,0 +1,191 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_generic() {
+subroutine distribute_parallel_do_generic()
+ ! CHECK: omp.target
+ ! CHECK-NOT: host_eval({{.*}})
+ ! CHECK-SAME: {
+ !$omp target
+ !$omp teams
+ !$omp distribute parallel do
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do
+ call bar() !< Prevents this from being SPMD.
+ !$omp end teams
+ !$omp end target
+
+ ! CHECK: omp.target
+ ! CHECK-NOT: host_eval({{.*}})
+ ! CHECK-SAME: {
+ !$omp target teams
+ !$omp distribute parallel do
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do
+ call bar() !< Prevents this from being SPMD.
+ !$omp end target teams
+
+ ! CHECK: omp.target
+ ! CHECK-NOT: host_eval({{.*}})
+ ! CHECK-SAME: {
+ !$omp target teams
+ !$omp distribute parallel do
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do
+
+ !$omp distribute parallel do
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do
+ !$omp end target teams
+end subroutine distribute_parallel_do_generic
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_spmd() {
+subroutine distribute_parallel_do_spmd()
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval({{.*}})
+ !$omp target
+ !$omp teams
+ !$omp distribute parallel do
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do
+ !$omp end teams
+ !$omp end target
+
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval({{.*}})
+ !$omp target teams
+ !$omp distribute parallel do
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do
+ !$omp end target teams
+end subroutine distribute_parallel_do_spmd
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_simd_generic() {
+subroutine distribute_parallel_do_simd_generic()
+ ! CHECK: omp.target
+ ! CHECK-NOT: host_eval({{.*}})
+ ! CHECK-SAME: {
+ !$omp target
+ !$omp teams
+ !$omp distribute parallel do simd
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do simd
+ call bar() !< Prevents this from being SPMD.
+ !$omp end teams
+ !$omp end target
+
+ ! CHECK: omp.target
+ ! CHECK-NOT: host_eval({{.*}})
+ ! CHECK-SAME: {
+ !$omp target teams
+ !$omp distribute parallel do simd
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do simd
+ call bar() !< Prevents this from being SPMD.
+ !$omp end target teams
+
+ ! CHECK: omp.target
+ ! CHECK-NOT: host_eval({{.*}})
+ ! CHECK-SAME: {
+ !$omp target teams
+ !$omp distribute parallel do simd
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do simd
+
+ !$omp distribute parallel do simd
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do simd
+ !$omp end target teams
+end subroutine distribute_parallel_do_simd_generic
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_simd_spmd() {
+subroutine distribute_parallel_do_simd_spmd()
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval({{.*}})
+ !$omp target
+ !$omp teams
+ !$omp distribute parallel do simd
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do simd
+ !$omp end teams
+ !$omp end target
+
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval({{.*}})
+ !$omp target teams
+ !$omp distribute parallel do simd
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end distribute parallel do simd
+ !$omp end target teams
+end subroutine distribute_parallel_do_simd_spmd
+
+! CHECK-LABEL: func.func @_QPteams_distribute_parallel_do_spmd() {
+subroutine teams_distribute_parallel_do_spmd()
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval({{.*}})
+ !$omp target
+ !$omp teams distribute parallel do
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end teams distribute parallel do
+ !$omp end target
+end subroutine teams_distribute_parallel_do_spmd
+
+! CHECK-LABEL: func.func @_QPteams_distribute_parallel_do_simd_spmd() {
+subroutine teams_distribute_parallel_do_simd_spmd()
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval({{.*}})
+ !$omp target
+ !$omp teams distribute parallel do simd
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end teams distribute parallel do simd
+ !$omp end target
+end subroutine teams_distribute_parallel_do_simd_spmd
+
+! CHECK-LABEL: func.func @_QPtarget_teams_distribute_parallel_do_spmd() {
+subroutine target_teams_distribute_parallel_do_spmd()
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval({{.*}})
+ !$omp target teams distribute parallel do
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end target teams distribute parallel do
+end subroutine target_teams_distribute_parallel_do_spmd
+
+! CHECK-LABEL: func.func @_QPtarget_teams_distribute_parallel_do_simd_spmd() {
+subroutine target_teams_distribute_parallel_do_simd_spmd()
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval({{.*}})
+ !$omp target teams distribute parallel do simd
+ do i = 1, 10
+ call foo(i)
+ end do
+ !$omp end target teams distribute parallel do simd
+end subroutine target_teams_distribute_parallel_do_simd_spmd
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index 1247a871f93c6d..f9a85626a3f149 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -41,6 +41,12 @@ struct DeviceTypeClauseOps {
// Extra operation operand structures.
//===----------------------------------------------------------------------===//
+/// Clauses that correspond to operations other than omp.target, but might have
+/// to be evaluated outside of a parent target region.
+using HostEvaluatedOperands =
+ detail::Clauses<LoopRelatedClauseOps, NumTeamsClauseOps,
+ NumThreadsClauseOps, ThreadLimitClauseOps>;
+
// TODO: Add `indirect` clause.
using DeclareTargetOperands = detail::Clauses<DeviceTypeClauseOps>;