[Mlir-commits] [mlir] [MLIR][OpenMP] Add the host_eval clause (PR #116048)
Sergio Afonso
llvmlistbot at llvm.org
Wed Dec 4 04:23:55 PST 2024
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116048
>From d74cf356919ba5b7bda60dee9217b34e7140019a Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 7 Nov 2024 16:53:51 +0000
Subject: [PATCH] [MLIR][OpenMP] Add the host_eval clause
This patch adds the definition of a new `host_eval` clause, which defines entry
block arguments. It is intended to implement the passthrough approach discussed
in [this RFC](https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106)
to support host-evaluated clauses that apply to operations nested inside
`omp.target`.
---
mlir/docs/Dialects/OpenMPDialect/_index.md | 3 +-
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 38 +++++++++++++++++++
.../Dialect/OpenMP/OpenMPOpsInterfaces.td | 31 ++++++++++++---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 9 +++++
4 files changed, 74 insertions(+), 7 deletions(-)
diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md
index 3d28fe7819129f..03d5b95217cce0 100644
--- a/mlir/docs/Dialects/OpenMPDialect/_index.md
+++ b/mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the
introduction of private copies of the same underlying variable defined outside
the MLIR operation the clause is attached to. Currently, clauses with this
property can be classified into three main categories:
- - Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`.
+ - Map-like clauses: `host_eval` (compiler internal, not defined by the OpenMP
+ specification), `map`, `use_device_addr` and `use_device_ptr`.
- Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
- Privatization clauses: `private`.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 855deab94b2f16..0a06c2e0335768 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -444,6 +444,44 @@ class OpenMP_HintClauseSkip<
def OpenMP_HintClause : OpenMP_HintClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure to hold host-evaluated values.
+//===----------------------------------------------------------------------===//
+
+class OpenMP_HostEvalClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let traits = [
+ BlockArgOpenMPOpInterface
+ ];
+
+ let arguments = (ins
+ Variadic<AnyType>:$host_eval_vars
+ );
+
+ let extraClassDeclaration = [{
+ unsigned numHostEvalBlockArgs() {
+ return getHostEvalVars().size();
+ }
+ }];
+
+ let description = [{
+ The optional `host_eval_vars` holds values defined outside of the region of
+ the `IsolatedFromAbove` operation for which a corresponding entry block
+ argument is defined. The only legal uses for these captured values are the
+ following:
+ - `num_teams` or `thread_limit` clause of an immediately nested
+ `omp.teams` operation.
+ - If the operation is the top-level `omp.target` of a target SPMD kernel:
+ - `num_threads` clause of the nested `omp.parallel` operation.
+ - Bounds and steps of the nested `omp.loop_nest` operation.
+ }];
+}
+
+def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>;
+
//===----------------------------------------------------------------------===//
// V5.2: [3.4] `if` clause
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 8b72689dc3fd87..c68d4c81986615 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
let methods = [
// Default-implemented methods to be overriden by the corresponding clauses.
+ InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
+ "unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
+ return 0;
+ }]>,
InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
"unsigned", "numInReductionBlockArgs", (ins), [{}], [{
return 0;
@@ -54,10 +58,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
return 0;
}]>,
- // Unified access methods for clause-associated entry block arguments.
+ // Unified access methods for start indices of clause-associated entry block
+ // arguments.
+ InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
+ "unsigned", "getHostEvalBlockArgsStart", (ins), [{
+ return 0;
+ }]>,
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
"unsigned", "getInReductionBlockArgsStart", (ins), [{
- return 0;
+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+ return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
}]>,
InterfaceMethod<"Get start index of block arguments defined by `map`.",
"unsigned", "getMapBlockArgsStart", (ins), [{
@@ -91,6 +101,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
}]>,
+ // Unified access methods for clause-associated entry block arguments.
+ InterfaceMethod<"Get block arguments defined by `host_eval`.",
+ "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+ "getHostEvalBlockArgs", (ins), [{
+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+ return $_op->getRegion(0).getArguments().slice(
+ iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
+ }]>,
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getInReductionBlockArgs", (ins), [{
@@ -147,10 +165,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
let verify = [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
- unsigned expectedArgs = iface.numInReductionBlockArgs() +
- iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
- iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
- iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
+ unsigned expectedArgs = iface.numHostEvalBlockArgs() +
+ iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
+ iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
+ iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
+ iface.numUseDevicePtrBlockArgs();
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
return $_op->emitOpError() << "expected at least " << expectedArgs
<< " entry block argument(s)";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6d7dbbf58bbda7..f626d18e9f4d69 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -504,6 +504,7 @@ struct ReductionParseArgs {
: vars(vars), types(types), byref(byref), syms(syms) {}
};
struct AllRegionParseArgs {
+ std::optional<MapParseArgs> hostEvalArgs;
std::optional<ReductionParseArgs> inReductionArgs;
std::optional<MapParseArgs> mapArgs;
std::optional<PrivateParseArgs> privateArgs;
@@ -647,6 +648,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion,
AllRegionParseArgs args) {
llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
+ if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
+ args.hostEvalArgs)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid `host_eval` format";
+
if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
args.inReductionArgs)))
return parser.emitError(parser.getCurrentLocation())
@@ -812,6 +818,7 @@ struct ReductionPrintArgs {
: vars(vars), types(types), byref(byref), syms(syms) {}
};
struct AllRegionPrintArgs {
+ std::optional<MapPrintArgs> hostEvalArgs;
std::optional<ReductionPrintArgs> inReductionArgs;
std::optional<MapPrintArgs> mapArgs;
std::optional<PrivatePrintArgs> privateArgs;
@@ -902,6 +909,8 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
MLIRContext *ctx = op->getContext();
+ printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
+ args.hostEvalArgs);
printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
args.inReductionArgs);
printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
More information about the Mlir-commits mailing list