[llvm-branch-commits] [flang] [mlir] [MLIR][OpenMP] LLVM IR translation of host_eval (PR #116052)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 10 03:57:12 PST 2025


================
@@ -3889,6 +3889,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
   return builder.saveIP();
 }
 
+/// Walk the users of every `host_eval` block argument of the given
+/// `omp.target` operation and, based on which operand of the user the argument
+/// feeds, store the matching host value (i.e. the `host_eval` operand evaluated
+/// outside of the target region) into the corresponding output reference.
+///
+/// Recognized uses are the lower/upper number of teams and the thread limit of
+/// an `omp.teams` operation, and the number of threads of an `omp.parallel`
+/// operation. Uses by `omp.loop_nest` are accepted, but loop bounds and steps
+/// are not extracted yet. Any other use is treated as unreachable.
+///
+/// Output references with no matching `host_eval` use are left untouched.
+static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
+                                   Value &numTeamsLower, Value &numTeamsUpper,
+                                   Value &threadLimit) {
+  auto hostEvalIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
+  for (auto pair : llvm::zip_equal(targetOp.getHostEvalVars(),
+                                   hostEvalIface.getHostEvalBlockArgs())) {
+    Value hostValue = std::get<0>(pair);
+    Value regionArg = std::get<1>(pair);
+
+    for (Operation *use : regionArg.getUsers()) {
+      if (auto teamsOp = llvm::dyn_cast<omp::TeamsOp>(use)) {
+        // Only the first matching operand is recorded for this user.
+        if (teamsOp.getNumTeamsLower() == regionArg)
+          numTeamsLower = hostValue;
+        else if (teamsOp.getNumTeamsUpper() == regionArg)
+          numTeamsUpper = hostValue;
+        else if (teamsOp.getThreadLimit() == regionArg)
+          threadLimit = hostValue;
+        else
+          llvm_unreachable("unsupported host_eval use");
+        continue;
+      }
+
+      if (auto parallelOp = llvm::dyn_cast<omp::ParallelOp>(use)) {
+        if (parallelOp.getNumThreads() == regionArg)
+          numThreads = hostValue;
+        else
+          llvm_unreachable("unsupported host_eval use");
+        continue;
+      }
+
+      if (llvm::isa<omp::LoopNestOp>(use)) {
+        // TODO: Extract bounds and step values.
+        continue;
+      }
+
+      llvm_unreachable("unsupported host_eval use");
+    }
+  }
+}
+
+/// Return \p op casted to \c OpTy if it is an operation of that type.
+/// Otherwise, look at its parents: only the immediate parent operation if
+/// \p immediateParent is true, or any enclosing parent operation if it is
+/// false, and return the matching parent casted to \c OpTy.
+///
+/// If \p op is \c null, or neither it nor the inspected parent(s) are of the
+/// requested type, a \c null operation is returned.
+template <typename OpTy>
+static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
+  if (op == nullptr)
+    return OpTy();
+
+  OpTy self = dyn_cast<OpTy>(op);
+  if (self)
+    return self;
+
+  // `getParentOp()` may be null, hence the `_if_present` cast variant.
+  return immediateParent ? dyn_cast_if_present<OpTy>(op->getParentOp())
+                         : op->getParentOfType<OpTy>();
+}
+
+/// Populate `MinTeams`, `MaxTeams` and `MaxThreads` with the default values
+/// stated by the corresponding clauses, if constant.
+///
+/// These default values must be set before the creation of the outlined LLVM
+/// function for the target region, so that they can be used to initialize the
+/// corresponding global `ConfigurationEnvironmentTy` structure.
+static void
+initTargetDefaultAttrs(omp::TargetOp targetOp,
+                       llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
+                       bool isTargetDevice) {
+  // TODO: Handle constant 'if' clauses.
+  Operation *capturedOp = targetOp.getInnermostCapturedOmpOp();
+
+  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+  if (!isTargetDevice) {
+    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
+                           threadLimit);
+  } else {
+    // In the target device, values for these clauses are not passed as
+    // host_eval, but instead evaluated prior to entry to the region. This
+    // ensures values are mapped and available inside of the target region.
+    if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
+      numTeamsLower = teamsOp.getNumTeamsLower();
+      numTeamsUpper = teamsOp.getNumTeamsUpper();
+      threadLimit = teamsOp.getThreadLimit();
+    }
+
+    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+      numThreads = parallelOp.getNumThreads();
+  }
+
+  auto extractConstInteger = [](Value value) -> std::optional<int64_t> {
----------------
skatrak wrote:

Good idea, done.

https://github.com/llvm/llvm-project/pull/116052


More information about the llvm-branch-commits mailing list