[llvm-branch-commits] [flang] [mlir] [MLIR][OpenMP] LLVM IR translation of host_eval (PR #116052)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 9 08:58:36 PST 2025
================
@@ -3889,6 +3889,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
return builder.saveIP();
}
+/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
+/// operation and populate output variables with their corresponding host value
+/// (i.e. operand evaluated outside of the target region), based on their uses
+/// inside of the target region. Any other kind of use is `llvm_unreachable`.
+///
+/// Loop bounds and steps are not extracted yet: uses by `omp.loop_nest` are
+/// accepted but ignored (see the TODO in the body).
+static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
+ Value &numTeamsLower, Value &numTeamsUpper,
+ Value &threadLimit) {
+ auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
+ for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
+ blockArgIface.getHostEvalBlockArgs())) {
+ Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
+
+ for (Operation *user : blockArg.getUsers()) {
+ llvm::TypeSwitch<Operation *>(user)
+ .Case([&](omp::TeamsOp teamsOp) {
+ if (teamsOp.getNumTeamsLower() == blockArg)
+ numTeamsLower = hostEvalVar;
+ else if (teamsOp.getNumTeamsUpper() == blockArg)
+ numTeamsUpper = hostEvalVar;
+ else if (teamsOp.getThreadLimit() == blockArg)
+ threadLimit = hostEvalVar;
+ else
+ llvm_unreachable("unsupported host_eval use");
+ })
+ .Case([&](omp::ParallelOp parallelOp) {
+ if (parallelOp.getNumThreads() == blockArg)
+ numThreads = hostEvalVar;
+ else
+ llvm_unreachable("unsupported host_eval use");
+ })
+ .Case([&](omp::LoopNestOp loopOp) {
+ // TODO: Extract bounds and step values; such uses are ignored for now.
+ })
+ .Default([](Operation *) {
+ llvm_unreachable("unsupported host_eval use");
+ });
+ }
+ }
+}
+
+/// If \p op is of the given type parameter, return it cast to that type.
+/// Otherwise, if its immediate parent operation (or some other higher-level
+/// parent, if \p immediateParent is false) is of that type, return that parent
+/// cast to the given type.
+///
+/// If \p op is \c null, or neither it nor any considered parent is of the
+/// specified type, return a \c null operation.
+template <typename OpTy>
+static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
+ if (!op)
+ return OpTy();
+
+ if (OpTy casted = dyn_cast<OpTy>(op))
+ return casted;
+
+ if (immediateParent)
+ return dyn_cast_if_present<OpTy>(op->getParentOp());
+
+ return op->getParentOfType<OpTy>();
+}
+
+/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
+/// values as stated by the corresponding clauses, if constant.
+///
+/// These default values must be set before the creation of the outlined LLVM
+/// function for the target region, so that they can be used to initialize the
+/// corresponding global `ConfigurationEnvironmentTy` structure.
+static void
+initTargetDefaultAttrs(omp::TargetOp targetOp,
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
+ bool isTargetDevice) {
+ // TODO: Handle constant 'if' clauses.
+ Operation *capturedOp = targetOp.getInnermostCapturedOmpOp();
+
+ Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+ if (!isTargetDevice) {
+ extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
+ threadLimit);
+ } else {
+ // In the target device, values for these clauses are not passed as
+ // host_eval, but instead evaluated prior to entry to the region. This
+ // ensures values are mapped and available inside of the target region.
+ if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
+ numTeamsLower = teamsOp.getNumTeamsLower();
+ numTeamsUpper = teamsOp.getNumTeamsUpper();
+ threadLimit = teamsOp.getThreadLimit();
+ }
+
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+ numThreads = parallelOp.getNumThreads();
+ }
+
+ auto extractConstInteger = [](Value value) -> std::optional<int64_t> {
----------------
agozillon wrote:
Might honestly be worth just making this a function, as it's possibly useful elsewhere in the file, but I'll leave that up to you as usual! :-)
https://github.com/llvm/llvm-project/pull/116052
More information about the llvm-branch-commits
mailing list