[llvm-branch-commits] [flang] [mlir] [Flang][OpenMP] Lowering of host-evaluated clauses (PR #116219)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 10 04:06:36 PST 2025
================
@@ -55,6 +55,149 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
const ConstructQueue &queue,
ConstructQueue::const_iterator item);
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ lower::StatementContext &stmtCtx,
+ lower::pft::Evaluation &eval,
+ mlir::Location loc);
+
+namespace {
+/// Structure holding information that is needed to pass host-evaluated
+/// information to later lowering stages.
+class HostEvalInfo {
+public:
+ // Allow this function access to private members in order to initialize them.
+ friend void ::processHostEvalClauses(lower::AbstractConverter &,
+ semantics::SemanticsContext &,
+ lower::StatementContext &,
+ lower::pft::Evaluation &,
+ mlir::Location);
+
+ /// Fill \c vars with values stored in \c ops.
+ ///
+ /// The order in which values are stored matches the one expected by \see
+ /// bindOperands().
+ void collectValues(llvm::SmallVectorImpl<mlir::Value> &vars) const {
+ vars.append(ops.loopLowerBounds);
+ vars.append(ops.loopUpperBounds);
+ vars.append(ops.loopSteps);
+
+ if (ops.numTeamsLower)
+ vars.push_back(ops.numTeamsLower);
+
+ if (ops.numTeamsUpper)
+ vars.push_back(ops.numTeamsUpper);
+
+ if (ops.numThreads)
+ vars.push_back(ops.numThreads);
+
+ if (ops.threadLimit)
+ vars.push_back(ops.threadLimit);
+ }
+
+ /// Update \c ops, replacing all values with the corresponding block argument
+ /// in \c args.
+ ///
+ /// The order in which values are stored in \c args is the same as the one
+ /// used by \see collectValues().
+ void bindOperands(llvm::ArrayRef<mlir::BlockArgument> args) {
+ assert(args.size() ==
+ ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
+ ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
+ (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+ (ops.threadLimit ? 1 : 0) &&
+ "invalid block argument list");
+ int argIndex = 0;
+ for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
+ ops.loopLowerBounds[i] = args[argIndex++];
+
+ for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i)
+ ops.loopUpperBounds[i] = args[argIndex++];
+
+ for (size_t i = 0; i < ops.loopSteps.size(); ++i)
+ ops.loopSteps[i] = args[argIndex++];
+
+ if (ops.numTeamsLower)
+ ops.numTeamsLower = args[argIndex++];
+
+ if (ops.numTeamsUpper)
+ ops.numTeamsUpper = args[argIndex++];
+
+ if (ops.numThreads)
+ ops.numThreads = args[argIndex++];
+
+ if (ops.threadLimit)
+ ops.threadLimit = args[argIndex++];
+ }
+
+ /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
+ /// values and Fortran symbols, respectively, if they have already been
+ /// initialized but not yet applied.
+ ///
+ /// \returns whether an update was performed. If not, these clauses were not
+ /// evaluated in the host device.
+ bool apply(mlir::omp::LoopNestOperands &clauseOps,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &ivOut) {
+ if (iv.empty() || loopNestApplied) {
+ loopNestApplied = true;
+ return false;
+ }
+
+ loopNestApplied = true;
+ clauseOps.loopLowerBounds = ops.loopLowerBounds;
+ clauseOps.loopUpperBounds = ops.loopUpperBounds;
+ clauseOps.loopSteps = ops.loopSteps;
+ ivOut.append(iv);
+ return true;
+ }
+
+ /// Update \p clauseOps with the corresponding host-evaluated values if they
+ /// have already been initialized but not yet applied.
+ ///
+ /// \returns whether an update was performed. If not, these clauses were not
+ /// evaluated in the host device.
+ bool apply(mlir::omp::ParallelOperands &clauseOps) {
+ if (!ops.numThreads || parallelApplied) {
+ parallelApplied = true;
+ return false;
+ }
+
+ parallelApplied = true;
+ clauseOps.numThreads = ops.numThreads;
+ return true;
+ }
+
+ /// Update \p clauseOps with the corresponding host-evaluated values if they
+ /// have already been initialized.
+ ///
+ /// \returns whether an update was performed. If not, these clauses were not
+ /// evaluated in the host device.
+ bool apply(mlir::omp::TeamsOperands &clauseOps) {
----------------
skatrak wrote:
This is because there can't be multiple `teams` in a single `target` region, basically. And the only legal way for a `teams` construct to be nested somewhere inside another `teams` construct would be in a reverse-offload situation. In that case, executing this function for the inner `teams` would still not overwrite any of the previously applied values, because every time a `target` region is crossed, a new `HostEvalInfo` is added to the stack.
We need to do these checks for `parallel` and `loop_nest` because there can be multiple nesting levels of these in a single target region and we only evaluate the top-level ones in the host. If we didn't check, the `clauseOps` would be overwritten with information from the innermost instance of these constructs.
https://github.com/llvm/llvm-project/pull/116219
More information about the llvm-branch-commits
mailing list