[Mlir-commits] [mlir] [MLIR][OpenMP] Add the host_eval clause (PR #116048)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 13 05:41:23 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch adds the definition of a new `host_eval` clause, which defines entry block arguments for the region of the operation it is attached to. It is intended to implement the passthrough approach discussed in [this RFC](https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106) for supporting host-evaluated clauses that apply to operations nested inside `omp.target`.

---
Full diff: https://github.com/llvm/llvm-project/pull/116048.diff


4 Files Affected:

- (modified) mlir/docs/Dialects/OpenMPDialect/_index.md (+2-1) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+38) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+25-6) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+9) 


``````````diff
diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md
index 3d28fe7819129f..4e5d777d6c4f7f 100644
--- a/mlir/docs/Dialects/OpenMPDialect/_index.md
+++ b/mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the
 introduction of private copies of the same underlying variable defined outside
 the MLIR operation the clause is attached to. Currently, clauses with this
 property can be classified into three main categories:
-  - Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`.
+  - Map-like clauses: `host_eval`, `map`, `use_device_addr` and
+`use_device_ptr`.
   - Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
   - Privatization clauses: `private`.
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 855deab94b2f16..0a06c2e0335768 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -444,6 +444,44 @@ class OpenMP_HintClauseSkip<
 
 def OpenMP_HintClause : OpenMP_HintClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure to hold host-evaluated values.
+//===----------------------------------------------------------------------===//
+
+class OpenMP_HostEvalClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let traits = [
+    BlockArgOpenMPOpInterface
+  ];
+
+  let arguments = (ins
+    Variadic<AnyType>:$host_eval_vars
+  );
+
+  let extraClassDeclaration = [{
+    unsigned numHostEvalBlockArgs() {
+      return getHostEvalVars().size();
+    }
+  }];
+
+  let description = [{
+    The optional `host_eval_vars` holds values defined outside of the region of
+    the `IsolatedFromAbove` operation for which a corresponding entry block
+    argument is defined. The only legal uses for these captured values are the
+    following:
+      - `num_teams` or `thread_limit` clause of an immediately nested
+      `omp.teams` operation.
+      - If the operation is the top-level `omp.target` of a target SPMD kernel:
+        - `num_threads` clause of the nested `omp.parallel` operation.
+        - Bounds and steps of the nested `omp.loop_nest` operation.
+  }];
+}
+
+def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [3.4] `if` clause
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 8b72689dc3fd87..c68d4c81986615 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let methods = [
     // Default-implemented methods to be overriden by the corresponding clauses.
+    InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
+                    "unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
     InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
                     "unsigned", "numInReductionBlockArgs", (ins), [{}], [{
       return 0;
@@ -54,10 +58,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
       return 0;
     }]>,
 
-    // Unified access methods for clause-associated entry block arguments.
+    // Unified access methods for start indices of clause-associated entry block
+    // arguments.
+    InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
+                    "unsigned", "getHostEvalBlockArgsStart", (ins), [{
+      return 0;
+    }]>,
     InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
                     "unsigned", "getInReductionBlockArgsStart", (ins), [{
-      return 0;
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
     }]>,
     InterfaceMethod<"Get start index of block arguments defined by `map`.",
                     "unsigned", "getMapBlockArgsStart", (ins), [{
@@ -91,6 +101,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
       return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
     }]>,
 
+    // Unified access methods for clause-associated entry block arguments.
+    InterfaceMethod<"Get block arguments defined by `host_eval`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getHostEvalBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return $_op->getRegion(0).getArguments().slice(
+          iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
+    }]>,
     InterfaceMethod<"Get block arguments defined by `in_reduction`.",
                     "::llvm::MutableArrayRef<::mlir::BlockArgument>",
                     "getInReductionBlockArgs", (ins), [{
@@ -147,10 +165,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let verify = [{
     auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
-    unsigned expectedArgs = iface.numInReductionBlockArgs() +
-        iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
-        iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
-        iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
+    unsigned expectedArgs = iface.numHostEvalBlockArgs() +
+        iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
+        iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
+        iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
+        iface.numUseDevicePtrBlockArgs();
     if ($_op->getRegion(0).getNumArguments() < expectedArgs)
       return $_op->emitOpError() << "expected at least " << expectedArgs
                                  << " entry block argument(s)";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index a1de0831653e64..b3575b1ca4018e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -502,6 +502,7 @@ struct ReductionParseArgs {
       : vars(vars), types(types), byref(byref), syms(syms) {}
 };
 struct AllRegionParseArgs {
+  std::optional<MapParseArgs> hostEvalArgs;
   std::optional<ReductionParseArgs> inReductionArgs;
   std::optional<MapParseArgs> mapArgs;
   std::optional<PrivateParseArgs> privateArgs;
@@ -628,6 +629,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
                                        AllRegionParseArgs args) {
   llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
 
+  if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
+                                 args.hostEvalArgs)))
+    return parser.emitError(parser.getCurrentLocation())
+           << "invalid `host_eval` format";
+
   if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
                                  args.inReductionArgs)))
     return parser.emitError(parser.getCurrentLocation())
@@ -789,6 +795,7 @@ struct ReductionPrintArgs {
       : vars(vars), types(types), byref(byref), syms(syms) {}
 };
 struct AllRegionPrintArgs {
+  std::optional<MapPrintArgs> hostEvalArgs;
   std::optional<ReductionPrintArgs> inReductionArgs;
   std::optional<MapPrintArgs> mapArgs;
   std::optional<PrivatePrintArgs> privateArgs;
@@ -867,6 +874,8 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
   auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
   MLIRContext *ctx = op->getContext();
 
+  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
+                      args.hostEvalArgs);
   printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
                       args.inReductionArgs);
   printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),

``````````

</details>


https://github.com/llvm/llvm-project/pull/116048


More information about the Mlir-commits mailing list