[Mlir-commits] [mlir] 8906343 - [MLIR][OpenMP] Add the host_eval clause (#116048)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 14 02:19:50 PST 2025
Author: Sergio Afonso
Date: 2025-01-14T10:19:45Z
New Revision: 89063433792699c5913ba116cab09b534c549e56
URL: https://github.com/llvm/llvm-project/commit/89063433792699c5913ba116cab09b534c549e56
DIFF: https://github.com/llvm/llvm-project/commit/89063433792699c5913ba116cab09b534c549e56.diff
LOG: [MLIR][OpenMP] Add the host_eval clause (#116048)
This patch defines a new `host_eval` clause whose operands are passed to the
region as entry block arguments. It is intended to implement the passthrough
approach discussed in [this
RFC](https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106)
for supporting host-evaluated clauses that apply to operations nested
inside `omp.target`.
Added:
Modified:
mlir/docs/Dialects/OpenMPDialect/_index.md
mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md
index 3d28fe7819129f..03d5b95217cce0 100644
--- a/mlir/docs/Dialects/OpenMPDialect/_index.md
+++ b/mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the
introduction of private copies of the same underlying variable defined outside
the MLIR operation the clause is attached to. Currently, clauses with this
property can be classified into three main categories:
- - Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`.
+ - Map-like clauses: `host_eval` (compiler internal, not defined by the OpenMP
+ specification), `map`, `use_device_addr` and `use_device_ptr`.
- Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
- Privatization clauses: `private`.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 98d2e80ed2d81d..8af054be322a55 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -470,6 +470,44 @@ class OpenMP_HintClauseSkip<
def OpenMP_HintClause : OpenMP_HintClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure to hold host-evaluated values.
+//===----------------------------------------------------------------------===//
+
+ // Clause-like mixin for the compiler-internal `host_eval` operand list. It
+ // contributes the `host_eval_vars` operands, the traits needed by ops that
+ // carry them, an override of the interface's block-argument count and the
+ // operand description.
+ // NOTE(review): presumably, as with the other `*ClauseSkip` classes, setting
+ // a template bit to true lets the including op omit that field; confirm
+ // against `OpenMP_Clause`.
+ class OpenMP_HostEvalClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ // Every `host_eval_vars` operand gets a matching entry block argument, so the
+ // op must implement BlockArgOpenMPOpInterface. The op is also marked
+ // IsolatedFromAbove, which matches the description below: the values are
+ // defined outside the op's region and reach it through those block args.
+ let traits = [
+ BlockArgOpenMPOpInterface, IsolatedFromAbove
+ ];
+
+ // Variadic list of host-evaluated values, of any type.
+ let arguments = (ins
+ Variadic<AnyType>:$host_eval_vars
+ );
+
+ // Overrides the BlockArgOpenMPOpInterface default, which returns 0. One entry
+ // block argument is reserved per `host_eval_vars` operand.
+ let extraClassDeclaration = [{
+ unsigned numHostEvalBlockArgs() {
+ return getHostEvalVars().size();
+ }
+ }];
+
+ let description = [{
+ The optional `host_eval_vars` holds values defined outside of the region of
+ the `IsolatedFromAbove` operation for which a corresponding entry block
+ argument is defined. The only legal uses for these captured values are the
+ following:
+ - `num_teams` or `thread_limit` clause of an immediately nested
+ `omp.teams` operation.
+ - If the operation is the top-level `omp.target` of a target SPMD kernel:
+ - `num_threads` clause of the nested `omp.parallel` operation.
+ - Bounds and steps of the nested `omp.loop_nest` operation.
+ }];
+}
+
+ // Default instantiation: every template bit is false.
+def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>;
+
//===----------------------------------------------------------------------===//
// V5.2: [3.4] `if` clause
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index c4cf0f7afb3a34..c863e5772032c2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
let methods = [
// Default-implemented methods to be overriden by the corresponding clauses.
+ InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
+ "unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
+ return 0;
+ }]>,
InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
"unsigned", "numInReductionBlockArgs", (ins), [{}], [{
return 0;
@@ -54,10 +58,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
return 0;
}]>,
- // Unified access methods for clause-associated entry block arguments.
+ // Unified access methods for start indices of clause-associated entry block
+ // arguments.
+ InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
+ "unsigned", "getHostEvalBlockArgsStart", (ins), [{
+ return 0;
+ }]>,
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
"unsigned", "getInReductionBlockArgsStart", (ins), [{
- return 0;
+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+ return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
}]>,
InterfaceMethod<"Get start index of block arguments defined by `map`.",
"unsigned", "getMapBlockArgsStart", (ins), [{
@@ -91,6 +101,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
}]>,
+ // Unified access methods for clause-associated entry block arguments.
+ InterfaceMethod<"Get block arguments defined by `host_eval`.",
+ "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+ "getHostEvalBlockArgs", (ins), [{
+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+ return $_op->getRegion(0).getArguments().slice(
+ iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
+ }]>,
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getInReductionBlockArgs", (ins), [{
@@ -147,10 +165,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
let verify = [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
- unsigned expectedArgs = iface.numInReductionBlockArgs() +
- iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
- iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
- iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
+ unsigned expectedArgs = iface.numHostEvalBlockArgs() +
+ iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
+ iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
+ iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
+ iface.numUseDevicePtrBlockArgs();
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
return $_op->emitOpError() << "expected at least " << expectedArgs
<< " entry block argument(s)";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index ca7e08e9f18b5f..2235fe2ee668d0 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -504,6 +504,7 @@ struct ReductionParseArgs {
: vars(vars), types(types), byref(byref), syms(syms) {}
};
struct AllRegionParseArgs {
+ std::optional<MapParseArgs> hostEvalArgs;
std::optional<ReductionParseArgs> inReductionArgs;
std::optional<MapParseArgs> mapArgs;
std::optional<PrivateParseArgs> privateArgs;
@@ -647,6 +648,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion,
AllRegionParseArgs args) {
llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
+ if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
+ args.hostEvalArgs)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid `host_eval` format";
+
if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
args.inReductionArgs)))
return parser.emitError(parser.getCurrentLocation())
@@ -812,6 +818,7 @@ struct ReductionPrintArgs {
: vars(vars), types(types), byref(byref), syms(syms) {}
};
struct AllRegionPrintArgs {
+ std::optional<MapPrintArgs> hostEvalArgs;
std::optional<ReductionPrintArgs> inReductionArgs;
std::optional<MapPrintArgs> mapArgs;
std::optional<PrivatePrintArgs> privateArgs;
@@ -902,6 +909,8 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
MLIRContext *ctx = op->getContext();
+ printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
+ args.hostEvalArgs);
printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
args.inReductionArgs);
printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
More information about the Mlir-commits
mailing list