[llvm-branch-commits] [flang] [mlir] [MLIR][OpenMP] LLVM IR translation of host_eval (PR #116052)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 10 03:57:12 PST 2025


================
@@ -3889,6 +3889,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
   return builder.saveIP();
 }
 
+/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
+/// operation and populate output variables with their corresponding host value
+/// (i.e. operand evaluated outside of the target region), based on their uses
+/// inside of the target region.
+///
+/// Loop bounds and steps are only optionally populated, if output vectors are
+/// provided.
+static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
+                                   Value &numTeamsLower, Value &numTeamsUpper,
+                                   Value &threadLimit) {
+  auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
+  for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
+                                   blockArgIface.getHostEvalBlockArgs())) {
+    Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
+
+    for (Operation *user : blockArg.getUsers()) {
+      llvm::TypeSwitch<Operation *>(user)
+          .Case([&](omp::TeamsOp teamsOp) {
+            if (teamsOp.getNumTeamsLower() == blockArg)
+              numTeamsLower = hostEvalVar;
+            else if (teamsOp.getNumTeamsUpper() == blockArg)
+              numTeamsUpper = hostEvalVar;
+            else if (teamsOp.getThreadLimit() == blockArg)
+              threadLimit = hostEvalVar;
+            else
+              llvm_unreachable("unsupported host_eval use");
+          })
+          .Case([&](omp::ParallelOp parallelOp) {
+            if (parallelOp.getNumThreads() == blockArg)
+              numThreads = hostEvalVar;
+            else
+              llvm_unreachable("unsupported host_eval use");
+          })
+          .Case([&](omp::LoopNestOp loopOp) {
+            // TODO: Extract bounds and step values.
----------------
skatrak wrote:

The `checkImplementationStatus` check for the `omp.target` operation must already have detected and reported this case, stopping the translation process before this point. Adding an `llvm_unreachable` here therefore makes sense; thanks for the suggestion. I also added that explanation to the comment above.
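
As a rough illustration of that reasoning, here is a standalone sketch of the "validate first, then treat the case as unreachable" pattern. It is not the code in this PR and does not use the MLIR or LLVM APIs: `UserKind`, `checkSupported`, `countHandled` and `translate` are hypothetical stand-ins, and `assert` plus `std::abort` stand in for `llvm_unreachable`.

```cpp
#include <cassert>
#include <cstdlib>
#include <string>
#include <vector>

// Hypothetical stand-in for the kinds of operations that may use a
// host_eval block argument.
enum class UserKind { Teams, Parallel, LoopNest };

// Stage 1 (plays the role of an up-front implementation-status check):
// reports unsupported uses and tells the caller to stop translating.
bool checkSupported(const std::vector<UserKind> &users, std::string &error) {
  for (UserKind kind : users) {
    if (kind == UserKind::LoopNest) {
      error = "not yet implemented: host_eval use in loop nest";
      return false;
    }
  }
  return true;
}

// Stage 2 (plays the role of the extraction step): must only run after
// checkSupported succeeded, so the unsupported case is a programming error
// here rather than a user-facing diagnostic.
int countHandled(const std::vector<UserKind> &users) {
  int handled = 0;
  for (UserKind kind : users) {
    switch (kind) {
    case UserKind::Teams:
    case UserKind::Parallel:
      ++handled;
      break;
    case UserKind::LoopNest:
      // Already rejected by checkSupported; stand-in for llvm_unreachable.
      assert(false && "unsupported host_eval use");
      std::abort();
    }
  }
  return handled;
}

// Driver: the second stage is never entered when the first stage fails.
bool translate(const std::vector<UserKind> &users, int &handled,
               std::string &error) {
  if (!checkSupported(users, error))
    return false;
  handled = countHandled(users);
  return true;
}
```

In this sketch the user-facing error is produced only once, by the first stage; the second stage documents the invariant instead of duplicating the diagnostic.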

https://github.com/llvm/llvm-project/pull/116052

