[llvm] [mlir] [mlir][llvm] Translation support for task detach (PR #116601)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 18 02:56:45 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: None (NimishMishra)

<details>
<summary>Changes</summary>

This PR adds LLVM IR translation support for the `detach` clause on tasks. If the `detach` clause is present on a task, a call to `__kmpc_task_allow_completion_event` is emitted for that task, and the call's return value (of type `kmp_event_t*`) is stored into the `event_handle`.

---
Full diff: https://github.com/llvm/llvm-project/pull/116601.diff


9 Files Affected:

- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+4-1) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+20-6) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+31) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+1-1) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+1-1) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+2-1) 
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+1-1) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+8-2) 
- (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+17) 


``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3afb9d84278e81..817fcdf25e2e70 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1262,12 +1262,15 @@ class OpenMPIRBuilder {
   ///                    cannot be resumed until execution of the structured
   ///                    block that is associated with the generated task is
   ///                    completed.
+  /// \param EventHandle If present, signifies the event handle as part of
+  /// 			 the detach clause
   InsertPointOrErrorTy createTask(const LocationDescription &Loc,
                                   InsertPointTy AllocaIP,
                                   BodyGenCallbackTy BodyGenCB, bool Tied = true,
                                   Value *Final = nullptr,
                                   Value *IfCondition = nullptr,
-                                  SmallVector<DependData> Dependencies = {});
+                                  SmallVector<DependData> Dependencies = {},
+                                  Value *EventHandle = nullptr);
 
   /// Generator for the taskgroup construct
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d2e4dc1c85dfd2..61311e7f7652ff 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1816,11 +1816,10 @@ static Value *emitTaskDependencies(
   return DepArray;
 }
 
-OpenMPIRBuilder::InsertPointOrErrorTy
-OpenMPIRBuilder::createTask(const LocationDescription &Loc,
-                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                            bool Tied, Value *Final, Value *IfCondition,
-                            SmallVector<DependData> Dependencies) {
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
+    const LocationDescription &Loc, InsertPointTy AllocaIP,
+    BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition,
+    SmallVector<DependData> Dependencies, Value *EventHandle) {
 
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -1866,7 +1865,8 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
       Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
 
   OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
-                      TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
+                      EventHandle, TaskAllocaBB,
+                      ToBeDeleted](Function &OutlinedFn) mutable {
     // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.getNumUses() == 1 &&
            "there must be a single user for the outlined function");
@@ -1930,6 +1930,20 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
         TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
                       /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
                       /*task_func=*/&OutlinedFn});
+    // Emit detach clause initialization.
+    // evt = (typeof(evt))__kmpc_task_allow_completion_event(loc, tid,
+    // task_descriptor);
+    if (EventHandle) {
+      Function *TaskDetachFn = getOrCreateRuntimeFunctionPtr(
+          OMPRTL___kmpc_task_allow_completion_event);
+      llvm::Value *EventVal =
+          Builder.CreateCall(TaskDetachFn, {Ident, ThreadID, TaskData});
+      llvm::Value *EventHandleAddr =
+          Builder.CreatePointerBitCastOrAddrSpaceCast(EventHandle,
+                                                      Builder.getPtrTy(0));
+      EventVal = Builder.CreatePtrToInt(EventVal, Builder.getInt64Ty());
+      Builder.CreateStore(EventVal, EventHandleAddr);
+    }
 
     // Copy the arguments for outlined function
     if (HasShareds) {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 855deab94b2f16..3188feffd52707 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -909,6 +909,37 @@ class OpenMP_ParallelizationLevelClauseSkip<
 
 def OpenMP_ParallelizationLevelClause : OpenMP_ParallelizationLevelClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// OpenMPV5.2: [12.5.2] `detach` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_DetachClauseSkip<
+	bit traits = false, bit arguments = false, bit assemblyFormat = false,
+	bit description = false, bit extraClassDeclaration = false
+	> : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> {
+
+	let traits = [
+		BlockArgOpenMPOpInterface
+	];
+
+	let arguments = (ins
+		Optional<OpenMP_PointerLikeType>:$event_handle
+	);
+
+	let optAssemblyFormat = [{
+		`detach` `(` $event_handle `:` type($event_handle) `)`
+	}];
+
+	let description = [{
+		The detach clause specifies that the task generated by the construct on which it appears is a
+	detachable task. A new allow-completion event is created and connected to the completion of the
+	associated task region. The original event-handle is updated to represent that allow-completion
+	event before the task data environment is created.
+	}];
+}
+
+def OpenMP_DetachClause : OpenMP_DetachClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [12.4] `priority` clause
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index f79a3eb88e4b5e..024581e5363a32 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -621,7 +621,7 @@ def TaskOp : OpenMP_Op<"task", traits = [
     // TODO: Complete clause list (affinity, detach).
     OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_FinalClause,
     OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_MergeableClause,
-    OpenMP_PriorityClause, OpenMP_PrivateClause, OpenMP_UntiedClause
+    OpenMP_PriorityClause, OpenMP_PrivateClause, OpenMP_UntiedClause, OpenMP_DetachClause
   ], singleRegion = true> {
   let summary = "task construct";
   let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 19e0fa30a75715..e06cc676be1b47 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2215,7 +2215,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
                 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
                 clauses.priority, /*private_vars=*/clauses.privateVars,
                 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
-                clauses.untied);
+                clauses.untied, clauses.eventHandle);
 }
 
 LogicalResult TaskOp::verify() {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a39b27aa9e12dc..abf75259ed996e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1665,7 +1665,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
       moduleTranslation.getOpenMPBuilder()->createTask(
           ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
           moduleTranslation.lookupValue(taskOp.getFinal()),
-          moduleTranslation.lookupValue(taskOp.getIfExpr()), dds);
+          moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
+          moduleTranslation.lookupValue(taskOp.getEventHandle()));
 
   if (failed(handleError(afterIP, *taskOp)))
     return failure();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index aa41eea44f3ef4..ea0e478ce21c37 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1634,7 +1634,7 @@ func.func @omp_single_copyprivate(%data_var : memref<i32>) -> () {
 // -----
 
 func.func @omp_task_depend(%data_var: memref<i32>) {
-  // expected-error @below {{op expected as many depend values as depend variables}}
+  // expected-error @below {{'omp.task' op operand count (1) does not match with the total size (0) specified in attribute 'operandSegmentSizes'}} 
     "omp.task"(%data_var) ({
       "omp.terminator"() : () -> ()
     }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 69c53d1f77e841..02208117ba85b2 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1975,8 +1975,8 @@ func.func @omp_single_copyprivate(%data_var: memref<i32>) {
 }
 
 // CHECK-LABEL: @omp_task
-// CHECK-SAME: (%[[bool_var:.*]]: i1, %[[i64_var:.*]]: i64, %[[i32_var:.*]]: i32, %[[data_var:.*]]: memref<i32>)
-func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memref<i32>) {
+// CHECK-SAME: (%[[bool_var:.*]]: i1, %[[i64_var:.*]]: i64, %[[i32_var:.*]]: i32, %[[data_var:.*]]: memref<i32>, %[[event_handle:.*]]: !llvm.ptr)
+func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memref<i32>, %event_handle : !llvm.ptr) {
 
   // Checking simple task
   // CHECK: omp.task {
@@ -2055,6 +2055,12 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
     omp.terminator
   }
 
+  // Checking detach clause
+  // CHECK: omp.task detach(%[[event_handle]] : !llvm.ptr)
+  omp.task detach(%event_handle : !llvm.ptr){
+    omp.terminator
+  }
+
   // Checking multiple clauses
   // CHECK: omp.task allocate(%[[data_var]] : memref<i32> -> %[[data_var]] : memref<i32>)
   omp.task allocate(%data_var : memref<i32> -> %data_var : memref<i32>)
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index cdf94b1ceae11b..eeec03bf515b3f 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2505,6 +2505,23 @@ llvm.mlir.global internal @_QFsubEx() : i32
 
 // -----
 
+// CHECK-LABEL: define void @omp_task_detach
+// CHECK-SAME: (ptr %[[event_handle:.*]])
+llvm.func @omp_task_detach(%event_handle : !llvm.ptr){
+  // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
+  // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
+  // CHECK: %[[return_val:.*]] = call ptr @__kmpc_task_allow_completion_event(ptr {{.*}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
+  // CHECK: %[[conv:.*]] = ptrtoint ptr %[[return_val]] to i64
+  // CHECK: store i64 %[[conv]], ptr %[[event_handle]], align 4
+  // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
+  omp.task detach(%event_handle : !llvm.ptr){
+   omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: define void @omp_task
 // CHECK-SAME: (i32 %[[x:.+]], i32 %[[y:.+]], ptr %[[zaddr:.+]])
 llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/116601


More information about the llvm-commits mailing list