[llvm-branch-commits] [flang] [mlir] [Flang][OpenMP] Lowering of host-evaluated clauses (PR #116219)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Nov 27 04:31:49 PST 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116219

>From 581746291e48096b40b4b059cd52d44c9629f2dd Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 14 Nov 2024 12:24:15 +0000
Subject: [PATCH] [Flang][OpenMP] Lowering of host-evaluated clauses

This patch adds support for lowering OpenMP clauses and expressions attached to
constructs nested inside of a target region that need to be evaluated on the
host device. This is done through the set of `omp.target` operands and entry
block arguments defined by the `OpenMP_HostEvalClause`.

When lowering clauses for a target construct, a more involved
`processHostEvalClauses()` function is called, which looks at the current and
potentially other nested constructs in order to find and lower clauses that
need to be processed outside of the `omp.target` operation under construction.
This populates an instance of a global structure with the resulting MLIR
values.

The resulting list of host-evaluated values is used to initialize the
`host_eval` operands when constructing the `omp.target` operation. After that
operation's region is created, the stored values are replaced with the
corresponding entry block arguments.

Afterwards, while lowering nested operations, those that might potentially be
evaluated in the host (e.g. `num_teams`, `thread_limit`, `num_threads` and
`collapse`) check first whether there is an active global host-evaluated
information structure and whether it holds values referring to these clauses.
If that is the case, the stored values (referring to `omp.target` entry block
arguments at that stage) are used instead of lowering clauses again.
---
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 463 ++++++++++++++++--
 flang/test/Lower/OpenMP/host-eval.f90         | 157 ++++++
 flang/test/Lower/OpenMP/target-spmd.f90       | 191 ++++++++
 .../Dialect/OpenMP/OpenMPClauseOperands.h     |   6 +
 4 files changed, 788 insertions(+), 29 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/host-eval.f90
 create mode 100644 flang/test/Lower/OpenMP/target-spmd.f90

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 91f99ba4b0ca55..a8a203e531b1e9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -45,6 +45,19 @@ using namespace Fortran::lower::omp;
 // Code generation helper functions
 //===----------------------------------------------------------------------===//
 
+static void genOMPDispatch(lower::AbstractConverter &converter,
+                           lower::SymMap &symTable,
+                           semantics::SemanticsContext &semaCtx,
+                           lower::pft::Evaluation &eval, mlir::Location loc,
+                           const ConstructQueue &queue,
+                           ConstructQueue::const_iterator item);
+
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+                                   semantics::SemanticsContext &semaCtx,
+                                   lower::StatementContext &stmtCtx,
+                                   lower::pft::Evaluation &eval,
+                                   mlir::Location loc);
+
 namespace {
 /// Structure holding the information needed to create and bind entry block
 /// arguments associated to a single clause.
@@ -63,6 +76,7 @@ struct EntryBlockArgsEntry {
 /// Structure holding the information needed to create and bind entry block
 /// arguments associated to all clauses that can define them.
 struct EntryBlockArgs {
+  llvm::ArrayRef<mlir::Value> hostEvalVars;
   EntryBlockArgsEntry inReduction;
   EntryBlockArgsEntry map;
   EntryBlockArgsEntry priv;
@@ -85,18 +99,146 @@ struct EntryBlockArgs {
 
   auto getVars() const {
     return llvm::concat<const mlir::Value>(
-        inReduction.vars, map.vars, priv.vars, reduction.vars,
+        hostEvalVars, inReduction.vars, map.vars, priv.vars, reduction.vars,
         taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
   }
 };
+
+/// Holds the values of clauses evaluated on the host for one target region, so
+/// that later lowering stages can reuse them instead of lowering them again.
+class HostEvalInfo {
+public:
+  // Allow this function access to private members in order to initialize them.
+  friend void ::processHostEvalClauses(lower::AbstractConverter &,
+                                       semantics::SemanticsContext &,
+                                       lower::StatementContext &,
+                                       lower::pft::Evaluation &,
+                                       mlir::Location);
+
+  /// Append the host-evaluated values stored in \c ops to \c vars.
+  ///
+  /// Unset optional values are skipped. The resulting order matches the one
+  /// in which \c bindOperands() consumes the corresponding block arguments.
+  void collectValues(llvm::SmallVectorImpl<mlir::Value> &vars) const {
+    vars.append(ops.loopLowerBounds);
+    vars.append(ops.loopUpperBounds);
+    vars.append(ops.loopSteps);
+
+    if (ops.numTeamsLower)
+      vars.push_back(ops.numTeamsLower);
+
+    if (ops.numTeamsUpper)
+      vars.push_back(ops.numTeamsUpper);
+
+    if (ops.numThreads)
+      vars.push_back(ops.numThreads);
+
+    if (ops.threadLimit)
+      vars.push_back(ops.threadLimit);
+  }
+
+  /// Replace every value held in \c ops with its corresponding block argument
+  /// from \c args.
+  ///
+  /// \c args must hold exactly one entry per stored value, following the same
+  /// order used by \c collectValues().
+  void bindOperands(llvm::ArrayRef<mlir::BlockArgument> args) {
+    assert(args.size() ==
+               ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
+                   ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
+                   (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+                   (ops.threadLimit ? 1 : 0) &&
+           "invalid block argument list");
+    int argIndex = 0;
+    for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
+      ops.loopLowerBounds[i] = args[argIndex++];
+
+    for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i)
+      ops.loopUpperBounds[i] = args[argIndex++];
+
+    for (size_t i = 0; i < ops.loopSteps.size(); ++i)
+      ops.loopSteps[i] = args[argIndex++];
+
+    if (ops.numTeamsLower)
+      ops.numTeamsLower = args[argIndex++];
+
+    if (ops.numTeamsUpper)
+      ops.numTeamsUpper = args[argIndex++];
+
+    if (ops.numThreads)
+      ops.numThreads = args[argIndex++];
+
+    if (ops.threadLimit)
+      ops.threadLimit = args[argIndex++];
+  }
+
+  /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
+  /// values and Fortran symbols, respectively, if they have already been
+  /// initialized but not yet applied.
+  ///
+  /// \returns whether an update was performed. If not, either these clauses
+  ///          were not evaluated on the host or they were already applied.
+  bool apply(mlir::omp::LoopNestOperands &clauseOps,
+             llvm::SmallVectorImpl<const semantics::Symbol *> &ivOut) {
+    if (iv.empty() || loopNestApplied) {
+      loopNestApplied = true;
+      return false;
+    }
+
+    loopNestApplied = true;
+    clauseOps.loopLowerBounds = ops.loopLowerBounds;
+    clauseOps.loopUpperBounds = ops.loopUpperBounds;
+    clauseOps.loopSteps = ops.loopSteps;
+    ivOut.append(iv);
+    return true;
+  }
+
+  /// Update \p clauseOps with the corresponding host-evaluated values if they
+  /// have already been initialized but not yet applied.
+  ///
+  /// \returns whether an update was performed. If not, either these clauses
+  ///          were not evaluated on the host or they were already applied.
+  bool apply(mlir::omp::ParallelOperands &clauseOps) {
+    if (!ops.numThreads || parallelApplied) {
+      parallelApplied = true;
+      return false;
+    }
+
+    parallelApplied = true;
+    clauseOps.numThreads = ops.numThreads;
+    return true;
+  }
+
+  /// Update \p clauseOps with the corresponding host-evaluated values if any
+  /// of them has been initialized. This can be applied any number of times.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  ///          evaluated on the host device.
+  bool apply(mlir::omp::TeamsOperands &clauseOps) {
+    if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit)
+      return false;
+
+    clauseOps.numTeamsLower = ops.numTeamsLower;
+    clauseOps.numTeamsUpper = ops.numTeamsUpper;
+    clauseOps.threadLimit = ops.threadLimit;
+    return true;
+  }
+
+private:
+  mlir::omp::HostEvaluatedOperands ops;
+  llvm::SmallVector<const semantics::Symbol *> iv;
+  bool loopNestApplied = false, parallelApplied = false;
+};
 } // namespace
 
-static void genOMPDispatch(lower::AbstractConverter &converter,
-                           lower::SymMap &symTable,
-                           semantics::SemanticsContext &semaCtx,
-                           lower::pft::Evaluation &eval, mlir::Location loc,
-                           const ConstructQueue &queue,
-                           ConstructQueue::const_iterator item);
+/// Stack of \c HostEvalInfo objects representing the current nest of
+/// \c omp.target operations being created.
+///
+/// The current implementation prevents nested 'target' regions from breaking
+/// the handling of the outer region by keeping a stack of information
+/// structures, but it will probably still require some further work to support
+/// reverse offloading.
+static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
 
 /// Bind symbols to their corresponding entry block arguments.
 ///
@@ -219,6 +361,8 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
   };
 
   // Process in clause name alphabetical order to match block arguments order.
+  // Do not bind host_eval variables because they cannot be used inside of the
+  // corresponding region, except for very specific cases handled separately.
   bindPrivateLike(args.inReduction.syms, args.inReduction.vars,
                   op.getInReductionBlockArgs());
   bindMapLike(args.map.syms, op.getMapBlockArgs());
@@ -256,6 +400,246 @@ extractMappedBaseValues(llvm::ArrayRef<mlir::Value> vars,
   });
 }
 
+/// Get the directive enumeration value corresponding to the given OpenMP
+/// construct parse-tree node.
+llvm::omp::Directive
+extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
+  return common::visit(
+      common::visitors{
+          [](const parser::OpenMPAllocatorsConstruct &c) {
+            return llvm::omp::OMPD_allocators;
+          },
+          [](const parser::OpenMPAtomicConstruct &c) {
+            return llvm::omp::OMPD_atomic;
+          },
+          [](const parser::OpenMPBlockConstruct &c) {
+            return std::get<parser::OmpBlockDirective>(
+                       std::get<parser::OmpBeginBlockDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPCriticalConstruct &c) {
+            return llvm::omp::OMPD_critical;
+          },
+          [](const parser::OpenMPDeclarativeAllocate &c) {
+            return llvm::omp::OMPD_allocate;
+          },
+          [](const parser::OpenMPExecutableAllocate &c) {
+            return llvm::omp::OMPD_allocate;
+          },
+          [](const parser::OpenMPLoopConstruct &c) {
+            return std::get<parser::OmpLoopDirective>(
+                       std::get<parser::OmpBeginLoopDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPSectionConstruct &c) {
+            return llvm::omp::OMPD_section;
+          },
+          [](const parser::OpenMPSectionsConstruct &c) {
+            return std::get<parser::OmpSectionsDirective>(
+                       std::get<parser::OmpBeginSectionsDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPStandaloneConstruct &c) {
+            return common::visit(
+                common::visitors{
+                    [](const parser::OpenMPSimpleStandaloneConstruct &c) {
+                      return std::get<parser::OmpSimpleStandaloneDirective>(c.t)
+                          .v;
+                    },
+                    [](const parser::OpenMPFlushConstruct &c) {
+                      return llvm::omp::OMPD_flush;
+                    },
+                    [](const parser::OpenMPCancelConstruct &c) {
+                      return llvm::omp::OMPD_cancel;
+                    },
+                    [](const parser::OpenMPCancellationPointConstruct &c) {
+                      return llvm::omp::OMPD_cancellation_point;
+                    },
+                    [](const parser::OpenMPDepobjConstruct &c) {
+                      return llvm::omp::OMPD_depobj;
+                    }},
+                c.u);
+          }},
+      ompConstruct.u);
+}
+
+/// Populate the top of the global \c hostEvalInfo stack after processing the
+/// clauses of the given \p eval OpenMP target construct, or nested constructs,
+/// if these must be evaluated outside of the target region per the spec.
+///
+/// In particular, this will ensure that in 'target teams' and equivalent nested
+/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated
+/// on the host. Additionally, loop bounds, steps and the \c num_threads clause
+/// will also be evaluated on the host if a target SPMD construct is detected
+/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting).
+///
+/// The result, stored as a global, is intended to be used to populate the \c
+/// host_eval operands of the associated \c omp.target operation, and also to be
+/// checked and used by later lowering steps to populate the corresponding
+/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest
+/// operations.
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+                                   semantics::SemanticsContext &semaCtx,
+                                   lower::StatementContext &stmtCtx,
+                                   lower::pft::Evaluation &eval,
+                                   mlir::Location loc) {
+  // Append to `clauses` the begin and end directive clauses of the given OpenMP
+  // block or loop construct evaluation. Non-OpenMP evaluations leave `clauses`
+  // unchanged, whereas other OpenMP constructs are not expected here.
+  auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval,
+                                   List<Clause> &clauses) {
+    const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+    if (!ompEval)
+      return;
+
+    const parser::OmpClauseList *beginClauseList = nullptr;
+    const parser::OmpClauseList *endClauseList = nullptr;
+    common::visit(
+        common::visitors{
+            [&](const parser::OpenMPBlockConstruct &ompConstruct) {
+              const auto &beginDirective =
+                  std::get<parser::OmpBeginBlockDirective>(ompConstruct.t);
+              beginClauseList =
+                  &std::get<parser::OmpClauseList>(beginDirective.t);
+              endClauseList = &std::get<parser::OmpClauseList>(
+                  std::get<parser::OmpEndBlockDirective>(ompConstruct.t).t);
+            },
+            [&](const parser::OpenMPLoopConstruct &ompConstruct) {
+              const auto &beginDirective =
+                  std::get<parser::OmpBeginLoopDirective>(ompConstruct.t);
+              beginClauseList =
+                  &std::get<parser::OmpClauseList>(beginDirective.t);
+
+              if (auto &endDirective =
+                      std::get<std::optional<parser::OmpEndLoopDirective>>(
+                          ompConstruct.t))
+                endClauseList =
+                    &std::get<parser::OmpClauseList>(endDirective->t);
+            },
+            [&](const auto &) {}},
+        ompEval->u);
+
+    assert(beginClauseList && "expected begin directive");
+    clauses.append(makeClauses(*beginClauseList, semaCtx));
+
+    if (endClauseList)
+      clauses.append(makeClauses(*endClauseList, semaCtx));
+  };
+
+  // Return the directive that is immediately nested inside of the given
+  // `parent` evaluation, if it is its only non-end-statement nested evaluation
+  // and it represents an OpenMP construct.
+  auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent)
+      -> std::optional<llvm::omp::Directive> {
+    if (!parent.hasNestedEvaluations())
+      return std::nullopt;
+
+    llvm::omp::Directive dir;
+    auto &nested = parent.getFirstNestedEvaluation();
+    if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
+      dir = extractOmpDirective(*ompEval);
+    else
+      return std::nullopt;
+
+    for (auto &sibling : parent.getNestedEvaluations())
+      if (&sibling != &nested && !sibling.isEndStmt())
+        return std::nullopt;
+
+    return dir;
+  };
+
+  // Process the given evaluation assuming it's part of a 'target' construct or
+  // captured by one, and store results in the global `hostEvalInfo`.
+  std::function<void(lower::pft::Evaluation &, const List<Clause> &)>
+      processEval;
+  processEval = [&](lower::pft::Evaluation &eval, const List<Clause> &clauses) {
+    using namespace llvm::omp;
+    ClauseProcessor cp(converter, semaCtx, clauses);
+
+    // Call `processEval` recursively with the immediately nested evaluation and
+    // its corresponding clauses if there is a single nested evaluation
+    // representing an OpenMP directive that passes the given test.
+    auto processSingleNestedIf = [&](llvm::function_ref<bool(Directive)> test) {
+      std::optional<Directive> nestedDir = extractOnlyOmpNestedDir(eval);
+      if (!nestedDir || !test(*nestedDir))
+        return;
+
+      lower::pft::Evaluation &nestedEval = eval.getFirstNestedEvaluation();
+      List<lower::omp::Clause> nestedClauses;
+      extractClauses(nestedEval, nestedClauses);
+      processEval(nestedEval, nestedClauses);
+    };
+
+    const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+    if (!ompEval)
+      return;
+
+    HostEvalInfo &hostInfo = hostEvalInfo.back();
+
+    switch (extractOmpDirective(*ompEval)) {
+    // Cases where 'teams' and target SPMD clauses might be present.
+    case OMPD_teams_distribute_parallel_do:
+    case OMPD_teams_distribute_parallel_do_simd:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_distribute_parallel_do:
+    case OMPD_target_teams_distribute_parallel_do_simd:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_distribute_parallel_do:
+    case OMPD_distribute_parallel_do_simd:
+      cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+      cp.processNumThreads(stmtCtx, hostInfo.ops);
+      break;
+
+    // Cases where 'teams' clauses might be present, and target SPMD is
+    // possible by looking at nested evaluations.
+    case OMPD_teams:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      processSingleNestedIf([](Directive nestedDir) {
+        return nestedDir == OMPD_distribute_parallel_do ||
+               nestedDir == OMPD_distribute_parallel_do_simd;
+      });
+      break;
+
+    // Cases where only 'teams' host-evaluated clauses might be present.
+    case OMPD_teams_distribute:
+    case OMPD_teams_distribute_simd:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_distribute:
+    case OMPD_target_teams_distribute_simd:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      break;
+
+    // Standalone 'target' case.
+    case OMPD_target: {
+      processSingleNestedIf(
+          [](Directive nestedDir) { return topTeamsSet.test(nestedDir); });
+      break;
+    }
+    default:
+      break;
+    }
+  };
+
+  assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
+  const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+  assert(ompEval &&
+         llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
+         "expected TARGET construct evaluation");
+
+  // Use the whole list of clauses passed to the construct here, rather than the
+  // ones only applied to omp.target.
+  List<lower::omp::Clause> clauses;
+  extractClauses(eval, clauses);
+  processEval(eval, clauses);
+}
+
 static lower::pft::Evaluation *
 getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) {
   // Return the Evaluation of the innermost collapsed loop, or the current one
@@ -638,11 +1022,11 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
 
   llvm::SmallVector<mlir::Type> types;
   llvm::SmallVector<mlir::Location> locs;
-  unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
-                     args.priv.vars.size() + args.reduction.vars.size() +
-                     args.taskReduction.vars.size() +
-                     args.useDeviceAddr.vars.size() +
-                     args.useDevicePtr.vars.size();
+  unsigned numVars =
+      args.hostEvalVars.size() + args.inReduction.vars.size() +
+      args.map.vars.size() + args.priv.vars.size() +
+      args.reduction.vars.size() + args.taskReduction.vars.size() +
+      args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size();
   types.reserve(numVars);
   locs.reserve(numVars);
 
@@ -655,6 +1039,7 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
 
   // Populate block arguments in clause name alphabetical order to match
   // expected order by the BlockArgOpenMPOpInterface.
+  extractTypeLoc(args.hostEvalVars);
   extractTypeLoc(args.inReduction.vars);
   extractTypeLoc(args.map.vars);
   extractTypeLoc(args.priv.vars);
@@ -997,6 +1382,8 @@ static void genBodyOfTargetOp(
   mlir::Region &region = targetOp.getRegion();
   mlir::Block *entryBlock = genEntryBlock(converter, args, region);
   bindEntryBlockArgs(converter, targetOp, args);
+  if (!hostEvalInfo.empty())
+    hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
 
   // Check if cloning the bounds introduced any dependency on the outer region.
   // If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1172,7 +1559,10 @@ genLoopNestClauses(lower::AbstractConverter &converter,
                    mlir::Location loc, mlir::omp::LoopNestOperands &clauseOps,
                    llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
   ClauseProcessor cp(converter, semaCtx, clauses);
-  cp.processCollapse(loc, eval, clauseOps, iv);
+
+  if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv))
+    cp.processCollapse(loc, eval, clauseOps, iv);
+
   clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
 }
 
@@ -1202,7 +1592,10 @@ static void genParallelClauses(
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processAllocate(clauseOps);
   cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
-  cp.processNumThreads(stmtCtx, clauseOps);
+
+  if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps))
+    cp.processNumThreads(stmtCtx, clauseOps);
+
   cp.processProcBind(clauseOps);
   cp.processReduction(loc, clauseOps, reductionSyms);
 }
@@ -1249,8 +1642,8 @@ static void genSingleClauses(lower::AbstractConverter &converter,
 
 static void genTargetClauses(
     lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
-    lower::StatementContext &stmtCtx, const List<Clause> &clauses,
-    mlir::Location loc, bool processHostOnlyClauses,
+    lower::StatementContext &stmtCtx, lower::pft::Evaluation &eval,
+    const List<Clause> &clauses, mlir::Location loc,
     mlir::omp::TargetOperands &clauseOps,
     llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
     llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
@@ -1259,13 +1652,15 @@ static void genTargetClauses(
   cp.processDepend(clauseOps);
   cp.processDevice(stmtCtx, clauseOps);
   cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms);
+  if (!hostEvalInfo.empty()) {
+    // Only process host_eval if compiling for the host device.
+    processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
+    hostEvalInfo.back().collectValues(clauseOps.hostEvalVars);
+  }
   cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
   cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
   cp.processMap(loc, stmtCtx, clauseOps, &mapSyms);
-
-  if (processHostOnlyClauses)
-    cp.processNowait(clauseOps);
-
+  cp.processNowait(clauseOps);
   cp.processThreadLimit(stmtCtx, clauseOps);
 
   cp.processTODO<clause::Allocate, clause::Defaultmap, clause::Firstprivate,
@@ -1367,10 +1762,13 @@ static void genTeamsClauses(lower::AbstractConverter &converter,
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processAllocate(clauseOps);
   cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
-  cp.processNumTeams(stmtCtx, clauseOps);
-  cp.processThreadLimit(stmtCtx, clauseOps);
-  // TODO Support delayed privatization.
 
+  if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) {
+    cp.processNumTeams(stmtCtx, clauseOps);
+    cp.processThreadLimit(stmtCtx, clauseOps);
+  }
+
+  // TODO Support delayed privatization.
   cp.processTODO<clause::Reduction>(loc, llvm::omp::Directive::OMPD_teams);
 }
 
@@ -1709,17 +2107,19 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
             ConstructQueue::const_iterator item) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   lower::StatementContext stmtCtx;
+  bool isTargetDevice =
+      llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp())
+          .getIsTargetDevice();
 
-  bool processHostOnlyClauses =
-      !llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp())
-           .getIsTargetDevice();
+  // Introduce a new host_eval information structure for this target region.
+  if (!isTargetDevice)
+    hostEvalInfo.emplace_back();
 
   mlir::omp::TargetOperands clauseOps;
   llvm::SmallVector<const semantics::Symbol *> mapSyms, isDevicePtrSyms,
       hasDeviceAddrSyms;
-  genTargetClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
-                   processHostOnlyClauses, clauseOps, hasDeviceAddrSyms,
-                   isDevicePtrSyms, mapSyms);
+  genTargetClauses(converter, semaCtx, stmtCtx, eval, item->clauses, loc,
+                   clauseOps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms);
 
   DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
                            /*shouldCollectPreDeterminedSymbols=*/
@@ -1836,6 +2236,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   extractMappedBaseValues(clauseOps.mapVars, mapBaseValues);
 
   EntryBlockArgs args;
+  args.hostEvalVars = clauseOps.hostEvalVars;
   // TODO: Add in_reduction syms and vars.
   args.map.syms = mapSyms;
   args.map.vars = mapBaseValues;
@@ -1844,6 +2245,10 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
   genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, args, loc,
                     queue, item, dsp);
+
+  // Remove the host_eval information structure created for this target region.
+  if (!isTargetDevice)
+    hostEvalInfo.pop_back();
   return targetOp;
 }
 
diff --git a/flang/test/Lower/OpenMP/host-eval.f90 b/flang/test/Lower/OpenMP/host-eval.f90
new file mode 100644
index 00000000000000..32c52462b86a76
--- /dev/null
+++ b/flang/test/Lower/OpenMP/host-eval.f90
@@ -0,0 +1,157 @@
+! The "thread_limit" clause was added to the "target" construct in OpenMP 5.1.
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %s -o - | FileCheck %s --check-prefixes=BOTH,HOST
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes=BOTH,DEVICE
+
+! BOTH-LABEL: func.func @_QPteams
+subroutine teams()
+  ! BOTH: omp.target
+
+  ! HOST-SAME: host_eval(%{{.*}} -> %[[NUM_TEAMS:.*]], %{{.*}} -> %[[THREAD_LIMIT:.*]] : i32, i32)
+  
+  ! DEVICE-NOT: host_eval({{.*}})
+  ! DEVICE-SAME: {
+  !$omp target
+
+  ! BOTH: omp.teams
+
+  ! HOST-SAME: num_teams( to %[[NUM_TEAMS]] : i32) thread_limit(%[[THREAD_LIMIT]] : i32)
+  ! DEVICE-SAME: num_teams({{.*}}) thread_limit({{.*}})
+  !$omp teams num_teams(1) thread_limit(2)
+  call foo()
+  !$omp end teams
+
+  !$omp end target
+
+  ! BOTH: omp.teams
+  ! BOTH-SAME: num_teams({{.*}}) thread_limit({{.*}}) {
+  !$omp teams num_teams(1) thread_limit(2)
+  call foo()
+  !$omp end teams
+end subroutine teams
+
+! BOTH-LABEL: func.func @_QPdistribute_parallel_do
+subroutine distribute_parallel_do()
+  ! BOTH: omp.target
+  
+  ! HOST-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32)
+  
+  ! DEVICE-NOT: host_eval({{.*}})
+  ! DEVICE-SAME: {
+
+  ! BOTH: omp.teams
+  !$omp target teams
+
+  ! BOTH: omp.parallel
+
+  ! HOST-SAME: num_threads(%[[NUM_THREADS]] : i32)
+  ! DEVICE-SAME: num_threads({{.*}})
+
+  ! BOTH: omp.distribute
+  ! BOTH-NEXT: omp.wsloop
+  ! BOTH-NEXT: omp.loop_nest
+
+  ! HOST-SAME: (%{{.*}}) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]])
+  !$omp distribute parallel do num_threads(1)
+  do i=1,10
+    call foo()
+  end do
+  !$omp end distribute parallel do
+  !$omp end target teams
+
+  ! BOTH: omp.target
+  ! BOTH-NOT: host_eval({{.*}})
+  ! BOTH-SAME: {
+  ! BOTH: omp.teams
+  !$omp target teams
+  call foo() !< Prevents this from being SPMD.
+
+  ! BOTH: omp.parallel
+  ! BOTH-SAME: num_threads({{.*}})
+  ! BOTH: omp.distribute
+  ! BOTH-NEXT: omp.wsloop
+  !$omp distribute parallel do num_threads(1)
+  do i=1,10
+    call foo()
+  end do
+  !$omp end distribute parallel do
+  !$omp end target teams
+
+  ! BOTH: omp.teams
+  !$omp teams
+
+  ! BOTH: omp.parallel
+  ! BOTH-SAME: num_threads({{.*}})
+  ! BOTH: omp.distribute
+  ! BOTH-NEXT: omp.wsloop
+  !$omp distribute parallel do num_threads(1)
+  do i=1,10
+    call foo()
+  end do
+  !$omp end distribute parallel do
+  !$omp end teams
+end subroutine distribute_parallel_do
+
+! BOTH-LABEL: func.func @_QPdistribute_parallel_do_simd
+subroutine distribute_parallel_do_simd()
+  ! BOTH: omp.target
+  
+  ! HOST-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32)
+  
+  ! DEVICE-NOT: host_eval({{.*}})
+  ! DEVICE-SAME: {
+
+  ! BOTH: omp.teams
+  !$omp target teams
+
+  ! BOTH: omp.parallel
+
+  ! HOST-SAME: num_threads(%[[NUM_THREADS]] : i32)
+  ! DEVICE-SAME: num_threads({{.*}})
+
+  ! BOTH: omp.distribute
+  ! BOTH-NEXT: omp.wsloop
+  ! BOTH-NEXT: omp.simd
+  ! BOTH-NEXT: omp.loop_nest
+
+  ! HOST-SAME: (%{{.*}}) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]])
+  !$omp distribute parallel do simd num_threads(1)
+  do i=1,10
+    call foo()
+  end do
+  !$omp end distribute parallel do simd
+  !$omp end target teams
+
+  ! BOTH: omp.target
+  ! BOTH-NOT: host_eval({{.*}})
+  ! BOTH-SAME: {
+  ! BOTH: omp.teams
+  !$omp target teams
+  call foo() !< Prevents this from being SPMD.
+
+  ! BOTH: omp.parallel
+  ! BOTH-SAME: num_threads({{.*}})
+  ! BOTH: omp.distribute
+  ! BOTH-NEXT: omp.wsloop
+  ! BOTH-NEXT: omp.simd
+  !$omp distribute parallel do simd num_threads(1)
+  do i=1,10
+    call foo()
+  end do
+  !$omp end distribute parallel do simd
+  !$omp end target teams
+
+  ! BOTH: omp.teams
+  !$omp teams
+
+  ! BOTH: omp.parallel
+  ! BOTH-SAME: num_threads({{.*}})
+  ! BOTH: omp.distribute
+  ! BOTH-NEXT: omp.wsloop
+  ! BOTH-NEXT: omp.simd
+  !$omp distribute parallel do simd num_threads(1)
+  do i=1,10
+    call foo()
+  end do
+  !$omp end distribute parallel do simd
+  !$omp end teams
+end subroutine distribute_parallel_do_simd
diff --git a/flang/test/Lower/OpenMP/target-spmd.f90 b/flang/test/Lower/OpenMP/target-spmd.f90
new file mode 100644
index 00000000000000..43613819ccc8e9
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-spmd.f90
@@ -0,0 +1,191 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_generic() {
+subroutine distribute_parallel_do_generic()
+  ! CHECK: omp.target
+  ! CHECK-NOT: host_eval({{.*}})
+  ! CHECK-SAME: {
+  !$omp target
+  !$omp teams
+  !$omp distribute parallel do
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do
+  call bar() !< Prevents this from being SPMD.
+  !$omp end teams
+  !$omp end target
+
+  ! CHECK: omp.target
+  ! CHECK-NOT: host_eval({{.*}})
+  ! CHECK-SAME: {
+  !$omp target teams
+  !$omp distribute parallel do
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do
+  call bar() !< Prevents this from being SPMD.
+  !$omp end target teams
+
+  ! CHECK: omp.target
+  ! CHECK-NOT: host_eval({{.*}})
+  ! CHECK-SAME: {
+  !$omp target teams
+  !$omp distribute parallel do
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do
+  ! Second loop nest in the same teams region (presumably what makes this generic).
+  !$omp distribute parallel do
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do
+  !$omp end target teams
+end subroutine distribute_parallel_do_generic
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_spmd() {
+subroutine distribute_parallel_do_spmd()
+  ! CHECK: omp.target
+  ! CHECK-SAME: host_eval({{.*}})
+  !$omp target
+  !$omp teams
+  !$omp distribute parallel do
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do
+  !$omp end teams
+  !$omp end target
+  ! Same loop nest, this time under the combined target teams construct.
+  ! CHECK: omp.target
+  ! CHECK-SAME: host_eval({{.*}})
+  !$omp target teams
+  !$omp distribute parallel do
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do
+  !$omp end target teams
+end subroutine distribute_parallel_do_spmd
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_simd_generic() {
+subroutine distribute_parallel_do_simd_generic()
+  ! CHECK: omp.target
+  ! CHECK-NOT: host_eval({{.*}})
+  ! CHECK-SAME: {
+  !$omp target
+  !$omp teams
+  !$omp distribute parallel do simd
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do simd
+  call bar() !< Prevents this from being SPMD.
+  !$omp end teams
+  !$omp end target
+
+  ! CHECK: omp.target
+  ! CHECK-NOT: host_eval({{.*}})
+  ! CHECK-SAME: {
+  !$omp target teams
+  !$omp distribute parallel do simd
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do simd
+  call bar() !< Prevents this from being SPMD.
+  !$omp end target teams
+
+  ! CHECK: omp.target
+  ! CHECK-NOT: host_eval({{.*}})
+  ! CHECK-SAME: {
+  !$omp target teams
+  !$omp distribute parallel do simd
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do simd
+  ! Second loop nest in the same teams region (presumably what makes this generic).
+  !$omp distribute parallel do simd
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do simd
+  !$omp end target teams
+end subroutine distribute_parallel_do_simd_generic
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_simd_spmd() {
+subroutine distribute_parallel_do_simd_spmd()
+  ! CHECK: omp.target
+  ! CHECK-SAME: host_eval({{.*}})
+  !$omp target
+  !$omp teams
+  !$omp distribute parallel do simd
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do simd
+  !$omp end teams
+  !$omp end target
+  ! Same loop nest, this time under the combined target teams construct.
+  ! CHECK: omp.target
+  ! CHECK-SAME: host_eval({{.*}})
+  !$omp target teams
+  !$omp distribute parallel do simd
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end distribute parallel do simd
+  !$omp end target teams
+end subroutine distribute_parallel_do_simd_spmd
+
+! CHECK-LABEL: func.func @_QPteams_distribute_parallel_do_spmd() {
+subroutine teams_distribute_parallel_do_spmd() !< teams folded into the combined loop construct.
+  ! CHECK: omp.target
+  ! CHECK-SAME: host_eval({{.*}})
+  !$omp target
+  !$omp teams distribute parallel do
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end teams distribute parallel do
+  !$omp end target
+end subroutine teams_distribute_parallel_do_spmd
+
+! CHECK-LABEL: func.func @_QPteams_distribute_parallel_do_simd_spmd() {
+subroutine teams_distribute_parallel_do_simd_spmd() !< teams folded into the combined simd loop construct.
+  ! CHECK: omp.target
+  ! CHECK-SAME: host_eval({{.*}})
+  !$omp target
+  !$omp teams distribute parallel do simd
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end teams distribute parallel do simd
+  !$omp end target
+end subroutine teams_distribute_parallel_do_simd_spmd
+
+! CHECK-LABEL: func.func @_QPtarget_teams_distribute_parallel_do_spmd() {
+subroutine target_teams_distribute_parallel_do_spmd() !< target, teams and the loop in one combined construct.
+  ! CHECK: omp.target
+  ! CHECK-SAME: host_eval({{.*}})
+  !$omp target teams distribute parallel do
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end target teams distribute parallel do
+end subroutine target_teams_distribute_parallel_do_spmd
+
+! CHECK-LABEL: func.func @_QPtarget_teams_distribute_parallel_do_simd_spmd() {
+subroutine target_teams_distribute_parallel_do_simd_spmd() !< target, teams and the simd loop in one combined construct.
+  ! CHECK: omp.target
+  ! CHECK-SAME: host_eval({{.*}})
+  !$omp target teams distribute parallel do simd
+  do i = 1, 10
+    call foo(i)
+  end do
+  !$omp end target teams distribute parallel do simd
+end subroutine target_teams_distribute_parallel_do_simd_spmd
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index 1247a871f93c6d..f9a85626a3f149 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -41,6 +41,12 @@ struct DeviceTypeClauseOps {
 // Extra operation operand structures.
 //===----------------------------------------------------------------------===//
 
+/// Loop-related, num_teams, num_threads and thread_limit clause operands of
+/// operations other than omp.target that might need evaluating outside of it.
+using HostEvaluatedOperands =
+    detail::Clauses<LoopRelatedClauseOps, NumTeamsClauseOps,
+                    NumThreadsClauseOps, ThreadLimitClauseOps>;
+
 // TODO: Add `indirect` clause.
 using DeclareTargetOperands = detail::Clauses<DeviceTypeClauseOps>;
 



More information about the llvm-branch-commits mailing list