[llvm-branch-commits] [flang] [mlir] [MLIR][OpenMP] LLVM IR translation of host_eval (PR #116052)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 9 08:58:36 PST 2025
================
@@ -3889,6 +3889,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
return builder.saveIP();
}
+/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
+/// operation and populate output variables with their corresponding host value
+/// (i.e. operand evaluated outside of the target region), based on their uses
+/// inside of the target region.
+///
+/// Loop bounds and steps are only optionally populated, if output vectors are
+/// provided.
+static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
+ Value &numTeamsLower, Value &numTeamsUpper,
+ Value &threadLimit) {
+ auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
+ for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
+ blockArgIface.getHostEvalBlockArgs())) {
+ Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
+
+ for (Operation *user : blockArg.getUsers()) {
+ llvm::TypeSwitch<Operation *>(user)
+ .Case([&](omp::TeamsOp teamsOp) {
+ if (teamsOp.getNumTeamsLower() == blockArg)
+ numTeamsLower = hostEvalVar;
+ else if (teamsOp.getNumTeamsUpper() == blockArg)
+ numTeamsUpper = hostEvalVar;
+ else if (teamsOp.getThreadLimit() == blockArg)
+ threadLimit = hostEvalVar;
+ else
+ llvm_unreachable("unsupported host_eval use");
+ })
+ .Case([&](omp::ParallelOp parallelOp) {
+ if (parallelOp.getNumThreads() == blockArg)
+ numThreads = hostEvalVar;
+ else
+ llvm_unreachable("unsupported host_eval use");
+ })
+ .Case([&](omp::LoopNestOp loopOp) {
+ // TODO: Extract bounds and step values.
+ })
+ .Default([](Operation *) {
+ llvm_unreachable("unsupported host_eval use");
+ });
+ }
+ }
+}
+
+/// If \p op is of the given type parameter, return it casted to that type.
+/// Otherwise, if its immediate parent operation (or some other higher-level
+/// parent, if \p immediateParent is false) is of that type, return that parent
+/// casted to the given type.
+///
+/// If \p op is \c null or neither it or its parent(s) are of the specified
+/// type, return a \c null operation.
+template <typename OpTy>
+static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
+ if (!op)
+ return OpTy();
+
+ if (OpTy casted = dyn_cast<OpTy>(op))
+ return casted;
+
+ if (immediateParent)
+ return dyn_cast_if_present<OpTy>(op->getParentOp());
+
+ return op->getParentOfType<OpTy>();
+}
+
+/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
+/// values as stated by the corresponding clauses, if constant.
+///
+/// These default values must be set before the creation of the outlined LLVM
+/// function for the target region, so that they can be used to initialize the
+/// corresponding global `ConfigurationEnvironmentTy` structure.
+static void
+initTargetDefaultAttrs(omp::TargetOp targetOp,
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
+ bool isTargetDevice) {
+ // TODO: Handle constant 'if' clauses.
+ Operation *capturedOp = targetOp.getInnermostCapturedOmpOp();
+
+ Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+ if (!isTargetDevice) {
+ extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
+ threadLimit);
+ } else {
+ // In the target device, values for these clauses are not passed as
+ // host_eval, but instead evaluated prior to entry to the region. This
+ // ensures values are mapped and available inside of the target region.
+ if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
+ numTeamsLower = teamsOp.getNumTeamsLower();
+ numTeamsUpper = teamsOp.getNumTeamsUpper();
+ threadLimit = teamsOp.getThreadLimit();
+ }
+
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+ numThreads = parallelOp.getNumThreads();
+ }
+
+ auto extractConstInteger = [](Value value) -> std::optional<int64_t> {
+ if (auto constOp =
+ dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
+ if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
+ return constAttr.getInt();
+
+ return std::nullopt;
+ };
+
+ // Handle clauses impacting the number of teams.
+
+ int32_t minTeamsVal = 1, maxTeamsVal = -1;
+ if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
+ // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
+ // clang and set min and max to the same value.
+ if (numTeamsUpper) {
+ if (auto val = extractConstInteger(numTeamsUpper))
+ minTeamsVal = maxTeamsVal = *val;
----------------
agozillon wrote:
Likely a dumb question: would it make sense to have an `else` here that sets minTeamsVal/maxTeamsVal to 0, as is done when numTeamsUpper is not retrievable? Perhaps it's fine to leave them at the defaults set above if that's what Clang does. If so, feel free to disregard :-)
https://github.com/llvm/llvm-project/pull/116052
More information about the llvm-branch-commits
mailing list