[llvm-branch-commits] [flang] [mlir] [Flang][OpenMP] Lowering of host-evaluated clauses (PR #116219)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Nov 14 04:45:53 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch adds support for lowering OpenMP clauses and expressions that are attached to constructs nested inside a target region but must be evaluated on the host device. It does so through the set of `omp.target` operands and entry block arguments defined by `OpenMP_HostEvalClause`.

When lowering clauses for a target construct, the new `processHostEvalClauses()` function is called. It inspects the current construct and, where applicable, the constructs nested in it to find and lower the clauses that must be processed outside of the `omp.target` operation under construction. The resulting MLIR values are stored in a `HostEvalInfo` structure kept on a global stack.
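
The "nested constructs" check can be pictured with a small standalone model. This is only a sketch: `Evaluation` and `Directive` here are hypothetical, simplified stand-ins for the PFT evaluation and `llvm::omp::Directive` types, and `onlyNestedOmpDirective` mirrors the idea behind the `extractOnlyOmpNestedDir` lambda in the patch rather than reproducing it.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical, simplified stand-ins for OpenMP directives and PFT evaluations.
enum class Directive { Teams, DistributeParallelDo, Parallel };

struct Evaluation {
  std::optional<Directive> ompDirective; // set only for OpenMP constructs
  bool isEndStmt = false;
  std::vector<Evaluation> nested;
};

// Return the directive of the first nested evaluation, provided that it is an
// OpenMP construct and every other nested evaluation is an end statement.
std::optional<Directive> onlyNestedOmpDirective(const Evaluation &parent) {
  if (parent.nested.empty())
    return std::nullopt;

  const Evaluation &first = parent.nested.front();
  if (!first.ompDirective)
    return std::nullopt;

  for (const Evaluation &sibling : parent.nested)
    if (&sibling != &first && !sibling.isEndStmt)
      return std::nullopt;

  return first.ompDirective;
}
```

In this model, a target evaluation whose body holds a single teams construct followed by an end statement yields `Directive::Teams`, while a body that also contains an ordinary statement yields no directive, so nothing from the nested construct would be hoisted to the host.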

The resulting list of host-evaluated values is used to initialize the `host_eval` operands when constructing the `omp.target` operation. After that operation's region has been created, the stored values are replaced with the corresponding entry block arguments.
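
The collect-then-bind step relies on both passes walking the values in the same order. The following is a minimal standalone sketch of that invariant, not the patch's code: `Value` is a plain integer standing in for `mlir::Value`, `HostEvalOps` is a hypothetical reduction of `mlir::omp::HostEvaluatedOperands`, and the two free functions model `HostEvalInfo::collectValues()` and `HostEvalInfo::bindOperands()`.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <optional>
#include <vector>

// Hypothetical stand-in for mlir::Value. Small ids play the role of host
// values, ids from 100 upwards play the role of entry block arguments.
using Value = int;

struct HostEvalOps {
  std::vector<Value> loopLowerBounds, loopUpperBounds, loopSteps;
  std::optional<Value> numTeamsLower, numTeamsUpper, numThreads, threadLimit;
};

// Flatten all present values in a fixed order: loop bounds and steps first,
// then the optional scalar clauses.
std::vector<Value> collectValues(const HostEvalOps &ops) {
  std::vector<Value> vars;
  for (const auto *list :
       {&ops.loopLowerBounds, &ops.loopUpperBounds, &ops.loopSteps})
    vars.insert(vars.end(), list->begin(), list->end());
  for (const auto *opt : {&ops.numTeamsLower, &ops.numTeamsUpper,
                          &ops.numThreads, &ops.threadLimit})
    if (*opt)
      vars.push_back(**opt);
  return vars;
}

// Replace every present value with the next block argument, in the same order
// used by collectValues().
void bindOperands(HostEvalOps &ops, const std::vector<Value> &args) {
  assert(args.size() == collectValues(ops).size() &&
         "invalid block argument list");
  std::size_t i = 0;
  for (auto *list :
       {&ops.loopLowerBounds, &ops.loopUpperBounds, &ops.loopSteps}) {
    for (Value &v : *list)
      v = args[i++];
  }
  for (auto *opt : {&ops.numTeamsLower, &ops.numTeamsUpper, &ops.numThreads,
                    &ops.threadLimit}) {
    if (*opt)
      *opt = args[i++];
  }
}

// Usage: a two-level loop nest with a num_threads clause and no teams clauses.
bool roundTripExample() {
  HostEvalOps ops;
  ops.loopLowerBounds = {1, 2};
  ops.loopUpperBounds = {3, 4};
  ops.loopSteps = {5, 6};
  ops.numThreads = 7;

  // These would become the host_eval operands of the target operation.
  std::vector<Value> operands = collectValues(ops);

  // One block argument per operand, created in the same order.
  std::vector<Value> blockArgs;
  for (std::size_t i = 0; i < operands.size(); ++i)
    blockArgs.push_back(100 + static_cast<Value>(i));

  bindOperands(ops, blockArgs);
  return ops.loopSteps[1] == 105 && ops.numThreads == 106 &&
         !ops.numTeamsLower;
}
```

Absent optional clauses contribute neither an operand nor a block argument, which is why both functions have to test the same fields in the same sequence.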

Afterwards, while lowering nested operations, clauses that might have been evaluated on the host (e.g. `num_teams`, `thread_limit`, `num_threads` and `collapse`) are handled by first checking whether there is an active global host-evaluated information structure and whether it holds values for them. If so, the stored values (which refer to `omp.target` entry block arguments at that stage) are used instead of lowering these clauses again.
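
A rough standalone model of this lookup, restricted to `num_threads`, is shown below. It is a sketch under assumptions: `Value` is an integer stand-in for `mlir::Value`, the class keeps only the `apply(ParallelOperands &)` overload from the patch, and `lowerNumThreads` is a hypothetical caller, since the part of the diff that contains the real call sites is truncated.

```cpp
#include <cassert>
#include <optional>
#include <vector>

using Value = int; // hypothetical stand-in for mlir::Value

struct ParallelOperands {
  std::optional<Value> numThreads;
};

class HostEvalInfo {
public:
  explicit HostEvalInfo(std::optional<Value> numThreads)
      : numThreads(numThreads) {}

  // Hand over the host-evaluated num_threads value at most once. Returns
  // false if there is no such value or it has already been applied.
  bool apply(ParallelOperands &clauseOps) {
    if (!numThreads || parallelApplied) {
      parallelApplied = true;
      return false;
    }
    parallelApplied = true;
    clauseOps.numThreads = numThreads;
    return true;
  }

private:
  std::optional<Value> numThreads;
  bool parallelApplied = false;
};

// One entry per target region currently being lowered, innermost last.
static std::vector<HostEvalInfo> hostEvalInfo;

// Hypothetical call site: reuse the host-evaluated value if one is active and
// unused, otherwise fall back to the value lowered locally.
Value lowerNumThreads(Value locallyLowered) {
  ParallelOperands clauseOps;
  if (!hostEvalInfo.empty() && hostEvalInfo.back().apply(clauseOps))
    return *clauseOps.numThreads;
  return locallyLowered;
}
```

With an entry holding block argument `100` on the stack, the first call to `lowerNumThreads` returns `100` and any later call falls back to its own argument, matching the "initialized but not yet applied" wording in the patch's comments.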

---

Patch is 34.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116219.diff


4 Files Affected:

- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+429-29) 
- (added) flang/test/Lower/OpenMP/host-eval.f90 (+138) 
- (added) flang/test/Lower/OpenMP/target-spmd.f90 (+191) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h (+6) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 91f99ba4b0ca55..a206af77a2f51f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -45,6 +45,19 @@ using namespace Fortran::lower::omp;
 // Code generation helper functions
 //===----------------------------------------------------------------------===//
 
+static void genOMPDispatch(lower::AbstractConverter &converter,
+                           lower::SymMap &symTable,
+                           semantics::SemanticsContext &semaCtx,
+                           lower::pft::Evaluation &eval, mlir::Location loc,
+                           const ConstructQueue &queue,
+                           ConstructQueue::const_iterator item);
+
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+                                   semantics::SemanticsContext &semaCtx,
+                                   lower::StatementContext &stmtCtx,
+                                   lower::pft::Evaluation &eval,
+                                   mlir::Location loc);
+
 namespace {
 /// Structure holding the information needed to create and bind entry block
 /// arguments associated to a single clause.
@@ -63,6 +76,7 @@ struct EntryBlockArgsEntry {
 /// Structure holding the information needed to create and bind entry block
 /// arguments associated to all clauses that can define them.
 struct EntryBlockArgs {
+  llvm::ArrayRef<mlir::Value> hostEvalVars;
   EntryBlockArgsEntry inReduction;
   EntryBlockArgsEntry map;
   EntryBlockArgsEntry priv;
@@ -85,18 +99,146 @@ struct EntryBlockArgs {
 
   auto getVars() const {
     return llvm::concat<const mlir::Value>(
-        inReduction.vars, map.vars, priv.vars, reduction.vars,
+        hostEvalVars, inReduction.vars, map.vars, priv.vars, reduction.vars,
         taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
   }
 };
+
+/// Structure holding information that is needed to pass host-evaluated
+/// information to later lowering stages.
+class HostEvalInfo {
+public:
+  // Allow this function access to private members in order to initialize them.
+  friend void ::processHostEvalClauses(lower::AbstractConverter &,
+                                       semantics::SemanticsContext &,
+                                       lower::StatementContext &,
+                                       lower::pft::Evaluation &,
+                                       mlir::Location);
+
+  /// Fill \c vars with values stored in \c ops.
+  ///
+  /// The order in which values are stored matches the one expected by \see
+  /// bindOperands().
+  void collectValues(llvm::SmallVectorImpl<mlir::Value> &vars) const {
+    vars.append(ops.loopLowerBounds);
+    vars.append(ops.loopUpperBounds);
+    vars.append(ops.loopSteps);
+
+    if (ops.numTeamsLower)
+      vars.push_back(ops.numTeamsLower);
+
+    if (ops.numTeamsUpper)
+      vars.push_back(ops.numTeamsUpper);
+
+    if (ops.numThreads)
+      vars.push_back(ops.numThreads);
+
+    if (ops.threadLimit)
+      vars.push_back(ops.threadLimit);
+  }
+
+  /// Update \c ops, replacing all values with the corresponding block argument
+  /// in \c args.
+  ///
+  /// The order in which values are stored in \c args is the same as the one
+  /// used by \see collectValues().
+  void bindOperands(llvm::ArrayRef<mlir::BlockArgument> args) {
+    assert(args.size() ==
+               ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
+                   ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
+                   (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+                   (ops.threadLimit ? 1 : 0) &&
+           "invalid block argument list");
+    int argIndex = 0;
+    for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
+      ops.loopLowerBounds[i] = args[argIndex++];
+
+    for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i)
+      ops.loopUpperBounds[i] = args[argIndex++];
+
+    for (size_t i = 0; i < ops.loopSteps.size(); ++i)
+      ops.loopSteps[i] = args[argIndex++];
+
+    if (ops.numTeamsLower)
+      ops.numTeamsLower = args[argIndex++];
+
+    if (ops.numTeamsUpper)
+      ops.numTeamsUpper = args[argIndex++];
+
+    if (ops.numThreads)
+      ops.numThreads = args[argIndex++];
+
+    if (ops.threadLimit)
+      ops.threadLimit = args[argIndex++];
+  }
+
+  /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
+  /// values and Fortran symbols, respectively, if they have already been
+  /// initialized but not yet applied.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  ///          evaluated in the host device.
+  bool apply(mlir::omp::LoopNestOperands &clauseOps,
+             llvm::SmallVectorImpl<const semantics::Symbol *> &ivOut) {
+    if (iv.empty() || loopNestApplied) {
+      loopNestApplied = true;
+      return false;
+    }
+
+    loopNestApplied = true;
+    clauseOps.loopLowerBounds = ops.loopLowerBounds;
+    clauseOps.loopUpperBounds = ops.loopUpperBounds;
+    clauseOps.loopSteps = ops.loopSteps;
+    ivOut.append(iv);
+    return true;
+  }
+
+  /// Update \p clauseOps with the corresponding host-evaluated values if they
+  /// have already been initialized but not yet applied.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  ///          evaluated in the host device.
+  bool apply(mlir::omp::ParallelOperands &clauseOps) {
+    if (!ops.numThreads || parallelApplied) {
+      parallelApplied = true;
+      return false;
+    }
+
+    parallelApplied = true;
+    clauseOps.numThreads = ops.numThreads;
+    return true;
+  }
+
+  /// Update \p clauseOps with the corresponding host-evaluated values if they
+  /// have already been initialized.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  ///          evaluated in the host device.
+  bool apply(mlir::omp::TeamsOperands &clauseOps) {
+    if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit)
+      return false;
+
+    clauseOps.numTeamsLower = ops.numTeamsLower;
+    clauseOps.numTeamsUpper = ops.numTeamsUpper;
+    clauseOps.threadLimit = ops.threadLimit;
+    return true;
+  }
+
+private:
+  mlir::omp::HostEvaluatedOperands ops;
+  llvm::SmallVector<const semantics::Symbol *> iv;
+  bool loopNestApplied = false, parallelApplied = false;
+};
 } // namespace
 
-static void genOMPDispatch(lower::AbstractConverter &converter,
-                           lower::SymMap &symTable,
-                           semantics::SemanticsContext &semaCtx,
-                           lower::pft::Evaluation &eval, mlir::Location loc,
-                           const ConstructQueue &queue,
-                           ConstructQueue::const_iterator item);
+/// Stack of \see HostEvalInfo to represent the current nest of \c omp.target
+/// operations being created.
+///
+/// The current implementation prevents nested 'target' regions from breaking
+/// the handling of the outer region by keeping a stack of information
+/// structures, but it will probably still require some further work to support
+/// reverse offloading.
+static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
 
 /// Bind symbols to their corresponding entry block arguments.
 ///
@@ -219,6 +361,8 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
   };
 
   // Process in clause name alphabetical order to match block arguments order.
+  // Do not bind host_eval variables because they cannot be used inside of the
+  // corresponding region, except for very specific cases handled separately.
   bindPrivateLike(args.inReduction.syms, args.inReduction.vars,
                   op.getInReductionBlockArgs());
   bindMapLike(args.map.syms, op.getMapBlockArgs());
@@ -256,6 +400,246 @@ extractMappedBaseValues(llvm::ArrayRef<mlir::Value> vars,
   });
 }
 
+/// Get the directive enumeration value corresponding to the given OpenMP
+/// construct PFT node.
+llvm::omp::Directive
+extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
+  return common::visit(
+      common::visitors{
+          [](const parser::OpenMPAllocatorsConstruct &c) {
+            return llvm::omp::OMPD_allocators;
+          },
+          [](const parser::OpenMPAtomicConstruct &c) {
+            return llvm::omp::OMPD_atomic;
+          },
+          [](const parser::OpenMPBlockConstruct &c) {
+            return std::get<parser::OmpBlockDirective>(
+                       std::get<parser::OmpBeginBlockDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPCriticalConstruct &c) {
+            return llvm::omp::OMPD_critical;
+          },
+          [](const parser::OpenMPDeclarativeAllocate &c) {
+            return llvm::omp::OMPD_allocate;
+          },
+          [](const parser::OpenMPExecutableAllocate &c) {
+            return llvm::omp::OMPD_allocate;
+          },
+          [](const parser::OpenMPLoopConstruct &c) {
+            return std::get<parser::OmpLoopDirective>(
+                       std::get<parser::OmpBeginLoopDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPSectionConstruct &c) {
+            return llvm::omp::OMPD_section;
+          },
+          [](const parser::OpenMPSectionsConstruct &c) {
+            return std::get<parser::OmpSectionsDirective>(
+                       std::get<parser::OmpBeginSectionsDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPStandaloneConstruct &c) {
+            return common::visit(
+                common::visitors{
+                    [](const parser::OpenMPSimpleStandaloneConstruct &c) {
+                      return std::get<parser::OmpSimpleStandaloneDirective>(c.t)
+                          .v;
+                    },
+                    [](const parser::OpenMPFlushConstruct &c) {
+                      return llvm::omp::OMPD_flush;
+                    },
+                    [](const parser::OpenMPCancelConstruct &c) {
+                      return llvm::omp::OMPD_cancel;
+                    },
+                    [](const parser::OpenMPCancellationPointConstruct &c) {
+                      return llvm::omp::OMPD_cancellation_point;
+                    },
+                    [](const parser::OpenMPDepobjConstruct &c) {
+                      return llvm::omp::OMPD_depobj;
+                    }},
+                c.u);
+          }},
+      ompConstruct.u);
+}
+
+/// Populate the global \see hostEvalInfo after processing clauses for the given
+/// \p eval OpenMP target construct, or nested constructs, if these must be
+/// evaluated outside of the target region per the spec.
+///
+/// In particular, this will ensure that in 'target teams' and equivalent nested
+/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated
+/// in the host. Additionally, loop bounds, steps and the \c num_threads clause
+/// will also be evaluated in the host if a target SPMD construct is detected
+/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting).
+///
+/// The result, stored as a global, is intended to be used to populate the \c
+/// host_eval operands of the associated \c omp.target operation, and also to be
+/// checked and used by later lowering steps to populate the corresponding
+/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest
+/// operations.
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+                                   semantics::SemanticsContext &semaCtx,
+                                   lower::StatementContext &stmtCtx,
+                                   lower::pft::Evaluation &eval,
+                                   mlir::Location loc) {
+  // Obtain the list of clauses of the given OpenMP block or loop construct
+  // evaluation. Other evaluations passed to this lambda keep `clauses`
+  // unchanged.
+  auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval,
+                                   List<Clause> &clauses) {
+    const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+    if (!ompEval)
+      return;
+
+    const parser::OmpClauseList *beginClauseList = nullptr;
+    const parser::OmpClauseList *endClauseList = nullptr;
+    common::visit(
+        common::visitors{
+            [&](const parser::OpenMPBlockConstruct &ompConstruct) {
+              const auto &beginDirective =
+                  std::get<parser::OmpBeginBlockDirective>(ompConstruct.t);
+              beginClauseList =
+                  &std::get<parser::OmpClauseList>(beginDirective.t);
+              endClauseList = &std::get<parser::OmpClauseList>(
+                  std::get<parser::OmpEndBlockDirective>(ompConstruct.t).t);
+            },
+            [&](const parser::OpenMPLoopConstruct &ompConstruct) {
+              const auto &beginDirective =
+                  std::get<parser::OmpBeginLoopDirective>(ompConstruct.t);
+              beginClauseList =
+                  &std::get<parser::OmpClauseList>(beginDirective.t);
+
+              if (auto &endDirective =
+                      std::get<std::optional<parser::OmpEndLoopDirective>>(
+                          ompConstruct.t))
+                endClauseList =
+                    &std::get<parser::OmpClauseList>(endDirective->t);
+            },
+            [&](const auto &) {}},
+        ompEval->u);
+
+    assert(beginClauseList && "expected begin directive");
+    clauses.append(makeClauses(*beginClauseList, semaCtx));
+
+    if (endClauseList)
+      clauses.append(makeClauses(*endClauseList, semaCtx));
+  };
+
+  // Return the directive that is immediately nested inside of the given
+  // `parent` evaluation, if it is its only non-end-statement nested evaluation
+  // and it represents an OpenMP construct.
+  auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent)
+      -> std::optional<llvm::omp::Directive> {
+    if (!parent.hasNestedEvaluations())
+      return std::nullopt;
+
+    llvm::omp::Directive dir;
+    auto &nested = parent.getFirstNestedEvaluation();
+    if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
+      dir = extractOmpDirective(*ompEval);
+    else
+      return std::nullopt;
+
+    for (auto &sibling : parent.getNestedEvaluations())
+      if (&sibling != &nested && !sibling.isEndStmt())
+        return std::nullopt;
+
+    return dir;
+  };
+
+  // Process the given evaluation assuming it's part of a 'target' construct or
+  // captured by one, and store results in the global `hostEvalInfo`.
+  std::function<void(lower::pft::Evaluation &, const List<Clause> &)>
+      processEval;
+  processEval = [&](lower::pft::Evaluation &eval, const List<Clause> &clauses) {
+    using namespace llvm::omp;
+    ClauseProcessor cp(converter, semaCtx, clauses);
+
+    // Call `processEval` recursively with the immediately nested evaluation and
+    // its corresponding clauses if there is a single nested evaluation
+    // representing an OpenMP directive that passes the given test.
+    auto processSingleNestedIf = [&](llvm::function_ref<bool(Directive)> test) {
+      std::optional<Directive> nestedDir = extractOnlyOmpNestedDir(eval);
+      if (!nestedDir || !test(*nestedDir))
+        return;
+
+      lower::pft::Evaluation &nestedEval = eval.getFirstNestedEvaluation();
+      List<lower::omp::Clause> nestedClauses;
+      extractClauses(nestedEval, nestedClauses);
+      processEval(nestedEval, nestedClauses);
+    };
+
+    const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+    if (!ompEval)
+      return;
+
+    HostEvalInfo &hostInfo = hostEvalInfo.back();
+
+    switch (extractOmpDirective(*ompEval)) {
+    // Cases where 'teams' and target SPMD clauses might be present.
+    case OMPD_teams_distribute_parallel_do:
+    case OMPD_teams_distribute_parallel_do_simd:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_distribute_parallel_do:
+    case OMPD_target_teams_distribute_parallel_do_simd:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_distribute_parallel_do:
+    case OMPD_distribute_parallel_do_simd:
+      cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+      cp.processNumThreads(stmtCtx, hostInfo.ops);
+      break;
+
+    // Cases where 'teams' clauses might be present, and target SPMD is
+    // possible by looking at nested evaluations.
+    case OMPD_teams:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      processSingleNestedIf([](Directive nestedDir) {
+        return nestedDir == OMPD_distribute_parallel_do ||
+               nestedDir == OMPD_distribute_parallel_do_simd;
+      });
+      break;
+
+    // Cases where only 'teams' host-evaluated clauses might be present.
+    case OMPD_teams_distribute:
+    case OMPD_teams_distribute_simd:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_distribute:
+    case OMPD_target_teams_distribute_simd:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      break;
+
+    // Standalone 'target' case.
+    case OMPD_target: {
+      processSingleNestedIf(
+          [](Directive nestedDir) { return topTeamsSet.test(nestedDir); });
+      break;
+    }
+    default:
+      break;
+    }
+  };
+
+  assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
+  const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+  assert(ompEval &&
+         llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
+         "expected TARGET construct evaluation");
+
+  // Use the whole list of clauses passed to the construct here, rather than the
+  // ones only applied to omp.target.
+  List<lower::omp::Clause> clauses;
+  extractClauses(eval, clauses);
+  processEval(eval, clauses);
+}
+
 static lower::pft::Evaluation *
 getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) {
   // Return the Evaluation of the innermost collapsed loop, or the current one
@@ -638,11 +1022,11 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
 
   llvm::SmallVector<mlir::Type> types;
   llvm::SmallVector<mlir::Location> locs;
-  unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
-                     args.priv.vars.size() + args.reduction.vars.size() +
-                     args.taskReduction.vars.size() +
-                     args.useDeviceAddr.vars.size() +
-                     args.useDevicePtr.vars.size();
+  unsigned numVars =
+      args.hostEvalVars.size() + args.inReduction.vars.size() +
+      args.map.vars.size() + args.priv.vars.size() +
+      args.reduction.vars.size() + args.taskReduction.vars.size() +
+      args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size();
   types.reserve(numVars);
   locs.reserve(numVars);
 
@@ -655,6 +1039,7 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
 
   // Populate block arguments in clause name alphabetical order to match
   // expected order by the BlockArgOpenMPOpInterface.
+  extractTypeLoc(args.hostEvalVars);
   extractTypeLoc(args.inReduction.vars);
   extractTypeLoc(args.map.vars);
   extractTypeLoc(args.priv.vars);
@@ -991,12 +1376,15 @@ static void genBodyOfTargetOp(
     mlir::omp::TargetOp &targetOp, const EntryBlockArgs &args,
     const mlir::Location &currentLocation, const ConstructQueue &queue,
     ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
+  assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);
 
   mlir::Region &region = targetOp.getRegion();
   mlir::Block *entryBlock = genEntryBlock(converter, args, region);
   bindEntryBlockArgs(converter, targetOp, args);
+  hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
 
   // Check if cloning the bounds introduced any dependency on the outer region.
   // If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1172,7 +1560,10 @@ genLoopNestClauses(lower::AbstractConverter &converter,
                    mlir::Location loc, ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/116219

