[llvm-branch-commits] [flang] [mlir] [Flang][OpenMP] Lowering of host-evaluated clauses (PR #116219)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 10 04:06:36 PST 2025


================
@@ -55,6 +55,149 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                            const ConstructQueue &queue,
                            ConstructQueue::const_iterator item);
 
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+                                   semantics::SemanticsContext &semaCtx,
+                                   lower::StatementContext &stmtCtx,
+                                   lower::pft::Evaluation &eval,
+                                   mlir::Location loc);
+
+namespace {
+/// Structure holding information that is needed to pass host-evaluated
+/// information to later lowering stages.
+class HostEvalInfo {
+public:
+  // Allow this function access to private members in order to initialize them.
+  friend void ::processHostEvalClauses(lower::AbstractConverter &,
+                                       semantics::SemanticsContext &,
+                                       lower::StatementContext &,
+                                       lower::pft::Evaluation &,
+                                       mlir::Location);
+
+  /// Fill \c vars with values stored in \c ops.
+  ///
+  /// The order in which values are stored matches the one expected by \see
+  /// bindOperands().
+  void collectValues(llvm::SmallVectorImpl<mlir::Value> &vars) const {
+    // Loop bounds and steps come first, each group appended as a whole.
+    vars.append(ops.loopLowerBounds);
+    vars.append(ops.loopUpperBounds);
+    vars.append(ops.loopSteps);
+
+    // Optional operands are only forwarded when they hold a value, so unset
+    // ones do not take up a position in the resulting list.
+    auto appendIfSet = [&vars](mlir::Value value) {
+      if (value)
+        vars.push_back(value);
+    };
+    appendIfSet(ops.numTeamsLower);
+    appendIfSet(ops.numTeamsUpper);
+    appendIfSet(ops.numThreads);
+    appendIfSet(ops.threadLimit);
+  }
+
+  /// Update \c ops, replacing all values with the corresponding block argument
+  /// in \c args.
+  ///
+  /// The order in which values are stored in \c args is the same as the one
+  /// used by \see collectValues().
+  void bindOperands(llvm::ArrayRef<mlir::BlockArgument> args) {
+    // An optional operand accounts for one block argument only when it is set.
+    auto slotsFor = [](mlir::Value value) -> size_t { return value ? 1 : 0; };
+    const size_t expectedArgs =
+        ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
+        ops.loopSteps.size() + slotsFor(ops.numTeamsLower) +
+        slotsFor(ops.numTeamsUpper) + slotsFor(ops.numThreads) +
+        slotsFor(ops.threadLimit);
+    assert(args.size() == expectedArgs && "invalid block argument list");
+    (void)expectedArgs; // Only read by the assertion above.
+
+    // Block arguments are consumed front to back: lower bounds, upper bounds,
+    // steps, and then each optional operand that is currently set.
+    size_t nextArg = 0;
+    auto rebindAll = [&](auto &values) {
+      for (auto &value : values)
+        value = args[nextArg++];
+    };
+    auto rebindIfSet = [&](auto &value) {
+      if (value)
+        value = args[nextArg++];
+    };
+
+    rebindAll(ops.loopLowerBounds);
+    rebindAll(ops.loopUpperBounds);
+    rebindAll(ops.loopSteps);
+
+    rebindIfSet(ops.numTeamsLower);
+    rebindIfSet(ops.numTeamsUpper);
+    rebindIfSet(ops.numThreads);
+    rebindIfSet(ops.threadLimit);
+  }
+
+  /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
+  /// values and Fortran symbols, respectively, if they have already been
+  /// initialized but not yet applied.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  ///          evaluated in the host device.
+  bool apply(mlir::omp::LoopNestOperands &clauseOps,
+             llvm::SmallVectorImpl<const semantics::Symbol *> &ivOut) {
+    // Decide before touching the flag: there must be induction variables
+    // recorded and no earlier call may have marked the loop nest as applied.
+    const bool shouldApply = !iv.empty() && !loopNestApplied;
+
+    // Every call marks the loop nest as applied, whether or not it transfers
+    // anything, so all later calls return false.
+    loopNestApplied = true;
+    if (!shouldApply)
+      return false;
+
+    clauseOps.loopLowerBounds = ops.loopLowerBounds;
+    clauseOps.loopUpperBounds = ops.loopUpperBounds;
+    clauseOps.loopSteps = ops.loopSteps;
+    ivOut.append(iv);
+    return true;
+  }
+
+  /// Update \p clauseOps with the corresponding host-evaluated values if they
+  /// have already been initialized but not yet applied.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  ///          evaluated in the host device.
+  bool apply(mlir::omp::ParallelOperands &clauseOps) {
+    // Decide before touching the flag: a num_threads value must be present and
+    // no earlier call may have marked the parallel clauses as applied.
+    const bool shouldApply = ops.numThreads && !parallelApplied;
+
+    // Every call marks the parallel clauses as applied, whether or not it
+    // transfers anything, so all later calls return false.
+    parallelApplied = true;
+    if (!shouldApply)
+      return false;
+
+    clauseOps.numThreads = ops.numThreads;
+    return true;
+  }
+
+  /// Update \p clauseOps with the corresponding host-evaluated values if they
+  /// have already been initialized.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  ///          evaluated in the host device.
+  bool apply(mlir::omp::TeamsOperands &clauseOps) {
----------------
skatrak wrote:

This is because there can't be multiple `teams` in a single `target` region, basically. And the only legal way for a `teams` construct to be nested somewhere inside another `teams` construct would be in a reverse-offload situation. In that case, executing this function for the inner `teams` would still not overwrite any of the previously applied values because every time a `target` region is crossed, a new `HostEvalInfo` is added to the stack.

We need to do these checks for `parallel` and `loop_nest` because there can be multiple nesting levels of these in a single target region and we only evaluate the top-level ones in the host. If we didn't check, the `clauseOps` would be overwritten with information from the innermost instance of these constructs.

https://github.com/llvm/llvm-project/pull/116219


More information about the llvm-branch-commits mailing list