[llvm] 9b57b16 - [OMPIRBuilder] Fix shared clause for task construct
    Prabhdeep Singh Soni via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Fri Sep 15 09:20:05 PDT 2023
    
    
  
Author: Prabhdeep Singh Soni
Date: 2023-09-15T12:19:47-04:00
New Revision: 9b57b167bb4d849b6803e28f638b970f493511f9
URL: https://github.com/llvm/llvm-project/commit/9b57b167bb4d849b6803e28f638b970f493511f9
DIFF: https://github.com/llvm/llvm-project/commit/9b57b167bb4d849b6803e28f638b970f493511f9.diff
LOG: [OMPIRBuilder] Fix shared clause for task construct
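This patch fixes the shared clause for the task construct with multiple
shared variables. The shareds field in kmp_task_t is not an inline
array in the struct; rather, it is a pointer to an array. Treating it as an
inline array meant that dereferencing the shareds pointer in the task's
outlined function body caused a segmentation fault when accessed by the
runtime. The patch therefore sizes the task allocation by the new
kmp_task_ompbuilder_t struct (sizeof_kmp_task_t), passes the size of the
captured-variable struct as sizeof_shareds, and loads the shareds pointer
from the allocated task both before copying the captured variables into it
and in the task wrapper before calling the outlined function.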
Reviewed By: kiranchandramohan, jdoerfert
Differential Revision: https://reviews.llvm.org/D158462
Added: 
    
Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
    mlir/test/Target/LLVMIR/openmp-llvm.mlir
Removed: 
    
################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index c4218326280b2ba..176b883fe68f7ad 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -95,6 +95,7 @@ __OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, false, Int32, Int32, VoidP
 		  Int64, Int64, Int32Arr3Ty, Int32Arr3Ty, Int32)
 __OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, false, Int8Ptr)
 __OMP_STRUCT_TYPE(DependInfo, kmp_dep_info, false, SizeTy, SizeTy, Int8)
+__OMP_STRUCT_TYPE(Task, kmp_task_ompbuilder_t, false, VoidPtr, VoidPtr, Int32, VoidPtr, VoidPtr)
 __OMP_STRUCT_TYPE(ConfigurationEnvironment, ConfigurationEnvironmentTy, false,
                   Int8, Int8, Int8)
 __OMP_STRUCT_TYPE(DynamicEnvironment, DynamicEnvironmentTy, false, Int16)
diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 8e9200aa1821985..1ace7d5b97ffc96 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1555,9 +1555,9 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
            "there must be a single user for the outlined function");
     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
 
-    // HasTaskData is true if any variables are captured in the outlined region,
+    // HasShareds is true if any variables are captured in the outlined region,
     // false otherwise.
-    bool HasTaskData = StaleCI->arg_size() > 0;
+    bool HasShareds = StaleCI->arg_size() > 0;
     Builder.SetInsertPoint(StaleCI);
 
     // Gather the arguments for emitting the runtime call for
@@ -1585,8 +1585,15 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
     // Argument - `sizeof_kmp_task_t` (TaskSize)
     // Tasksize refers to the size in bytes of kmp_task_t data structure
     // including private vars accessed in task.
-    Value *TaskSize = Builder.getInt64(0);
-    if (HasTaskData) {
+    // TODO: add kmp_task_t_with_privates (privates)
+    Value *TaskSize = Builder.getInt64(
+        divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
+
+    // Argument - `sizeof_shareds` (SharedsSize)
+    // SharedsSize refers to the shareds array size in the kmp_task_t data
+    // structure.
+    Value *SharedsSize = Builder.getInt64(0);
+    if (HasShareds) {
       AllocaInst *ArgStructAlloca =
           dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
       assert(ArgStructAlloca &&
@@ -1596,19 +1603,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
       assert(ArgStructType && "Unable to find struct type corresponding to "
                               "arguments for extracted function");
-      TaskSize =
+      SharedsSize =
           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
     }
 
-    // TODO: Argument - sizeof_shareds
-
     // Argument - task_entry (the wrapper function)
-    // If the outlined function has some captured variables (i.e. HasTaskData is
+    // If the outlined function has some captured variables (i.e. HasShareds is
     // true), then the wrapper function will have an additional argument (the
     // struct containing captured variables). Otherwise, no such argument will
     // be present.
     SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
-    if (HasTaskData)
+    if (HasShareds)
       WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
     FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
         (Twine(OutlinedFn.getName()) + ".wrapper").str(),
@@ -1617,19 +1622,19 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
 
     // Emit the @__kmpc_omp_task_alloc runtime call
     // The runtime call returns a pointer to an area where the task captured
-    // variables must be copied before the task is run (NewTaskData)
-    CallInst *NewTaskData = Builder.CreateCall(
-        TaskAllocFn,
-        {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
-         /*sizeof_task=*/TaskSize, /*sizeof_shared=*/Builder.getInt64(0),
-         /*task_func=*/WrapperFunc});
+    // variables must be copied before the task is run (TaskData)
+    CallInst *TaskData = Builder.CreateCall(
+        TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
+                      /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
+                      /*task_func=*/WrapperFunc});
 
     // Copy the arguments for outlined function
-    if (HasTaskData) {
-      Value *TaskData = StaleCI->getArgOperand(0);
+    if (HasShareds) {
+      Value *Shareds = StaleCI->getArgOperand(0);
       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
-      Builder.CreateMemCpy(NewTaskData, Alignment, TaskData, Alignment,
-                           TaskSize);
+      Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
+      Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
+                           SharedsSize);
     }
 
     Value *DepArrayPtr = nullptr;
@@ -1705,12 +1710,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
       Function *TaskCompleteFn =
           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
-      Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, NewTaskData});
-      if (HasTaskData)
-        Builder.CreateCall(WrapperFunc, {ThreadID, NewTaskData});
+      Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
+      if (HasShareds)
+        Builder.CreateCall(WrapperFunc, {ThreadID, TaskData});
       else
         Builder.CreateCall(WrapperFunc, {ThreadID});
-      Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
+      Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
       Builder.SetInsertPoint(ThenTI);
     }
 
@@ -1719,14 +1724,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
       Builder.CreateCall(
           TaskFn,
-          {Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()),
+          {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
            DepArrayPtr, ConstantInt::get(Builder.getInt32Ty(), 0),
            ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))});
 
     } else {
       // Emit the @__kmpc_omp_task runtime call to spawn the task
       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
-      Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+      Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
     }
 
     StaleCI->eraseFromParent();
@@ -1735,10 +1740,13 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
     BasicBlock *WrapperEntryBB =
         BasicBlock::Create(M.getContext(), "", WrapperFunc);
     Builder.SetInsertPoint(WrapperEntryBB);
-    if (HasTaskData)
-      Builder.CreateCall(&OutlinedFn, {WrapperFunc->getArg(1)});
-    else
+    if (HasShareds) {
+      llvm::Value *Shareds =
+          Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1));
+      Builder.CreateCall(&OutlinedFn, {Shareds});
+    } else {
       Builder.CreateCall(&OutlinedFn);
+    }
     Builder.CreateRet(Builder.getInt32(0));
   };
 
diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index bc1687cae25f1d9..2026824416f3e3c 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5397,19 +5397,29 @@ TEST_F(OpenMPIRBuilderTest, CreateTask) {
   ConstantInt *DataSize =
       dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
   ASSERT_NE(DataSize, nullptr);
-  EXPECT_EQ(DataSize->getSExtValue(), 24); // 64-bit pointer + 128-bit integer
+  EXPECT_EQ(DataSize->getSExtValue(), 40);
 
-  // TODO: Verify size of shared clause variables
+  ConstantInt *SharedsSize =
+      dyn_cast<ConstantInt>(TaskAllocCall->getOperand(4));
+  EXPECT_EQ(SharedsSize->getSExtValue(),
+            24); // 64-bit pointer + 128-bit integer
 
   // Verify Wrapper function
   Function *WrapperFunc =
       dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
   ASSERT_NE(WrapperFunc, nullptr);
+
+  LoadInst *SharedsLoad = dyn_cast<LoadInst>(WrapperFunc->begin()->begin());
+  ASSERT_NE(SharedsLoad, nullptr);
+  EXPECT_EQ(SharedsLoad->getPointerOperand(), WrapperFunc->getArg(1));
+
   EXPECT_FALSE(WrapperFunc->isDeclaration());
-  CallInst *OutlinedFnCall = dyn_cast<CallInst>(WrapperFunc->begin()->begin());
+  CallInst *OutlinedFnCall =
+      dyn_cast<CallInst>(++WrapperFunc->begin()->begin());
   ASSERT_NE(OutlinedFnCall, nullptr);
   EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
-  EXPECT_EQ(OutlinedFnCall->getArgOperand(0), WrapperFunc->getArg(1));
+  EXPECT_EQ(OutlinedFnCall->getArgOperand(0),
+            WrapperFunc->getArg(1)->uses().begin()->getUser());
 
   // Verify the presence of `trunc` and `icmp` instructions in Outlined function
   Function *OutlinedFn = OutlinedFnCall->getCalledFunction();
diff  --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 1dff33cdc22ce12..28b0113a19d61b8 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2208,7 +2208,7 @@ llvm.mlir.global internal @_QFsubEx() : i32
 llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
   // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
   // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
-  // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0,
+  // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
   // CHECK-SAME:  i64 0, ptr @[[wrapper_fn:.+]])
   // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
   omp.task {
@@ -2258,7 +2258,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
 llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
   // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
   // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
-  // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0,
+  // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
   // CHECK-SAME:  i64 0, ptr @[[wrapper_fn:.+]])
   // CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}})
   omp.task depend(taskdependin -> %zaddr : !llvm.ptr<i32>) {
@@ -2303,9 +2303,10 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
     llvm.store %diff, %zaddr : !llvm.ptr<i32>
     // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
     // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
-    // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 16, i64 0,
+    // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 16,
     // CHECK-SAME: ptr @[[wrapper_fn:.+]])
-    // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[task_data]], ptr {{.+}}, i64 16, i1 false)
+    // CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]]
+    // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[shareds]], ptr {{.+}}, i64 16, i1 false)
     // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
     omp.task {
       %z = llvm.add %x, %y : i32
@@ -2334,7 +2335,8 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
 
 
 // CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}, ptr %[[task_data:.+]]) {
-// CHECK:   call void @[[outlined_fn]](ptr %[[task_data]])
+// CHECK:  %[[shareds:.+]] = load ptr, ptr %1, align 8
+// CHECK:   call void @[[outlined_fn]](ptr %[[shareds]])
 // CHECK:   ret i32 0
 // CHECK: }
 
@@ -2430,7 +2432,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
 // CHECK:         br label %[[codeRepl:[^,]+]]
 // CHECK:       [[codeRepl]]:
 // CHECK:         %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK:         %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 0, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
+// CHECK:         %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
 // CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]])
 // CHECK:         br label %[[task_exit:[^,]+]]
 // CHECK:       [[task_exit]]:
@@ -2443,8 +2445,9 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
 // CHECK:         %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2
 // CHECK:         store ptr %[[zaddr]], ptr %[[gep3]], align 8
 // CHECK:         %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK:         %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 16, i64 0, ptr @omp_taskgroup_task..omp_par.1.wrapper)
-// CHECK:         call void @llvm.memcpy.p0.p0.i64(ptr align 8 %[[t2_alloc]], ptr align 8 %[[structArg]], i64 16, i1 false)
+// CHECK:         %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @omp_taskgroup_task..omp_par.1.wrapper)
+// CHECK:         %[[shareds:.+]] = load ptr, ptr %[[t2_alloc]]
+// CHECK:         call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[shareds]], ptr align 1 %[[structArg]], i64 16, i1 false)
 // CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], ptr %[[t2_alloc]])
 // CHECK:         br label %[[task_exit3:[^,]+]]
 // CHECK:       [[task_exit3]]:
@@ -2614,7 +2617,7 @@ llvm.func @omp_task_final(%boolexpr: i1) {
 // CHECK:         %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
 // CHECK:         %[[final_flag:.+]] = select i1 %[[boolexpr]], i32 2, i32 0
 // CHECK:         %[[task_flags:.+]] = or i32 %[[final_flag]], 1
-// CHECK:         %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 0, i64 0, ptr @omp_task_final..omp_par.wrapper)
+// CHECK:         %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @omp_task_final..omp_par.wrapper)
 // CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
 // CHECK:         br label %[[task_exit:[^,]+]]
 // CHECK:       [[task_exit]]:
@@ -2645,7 +2648,7 @@ llvm.func @omp_task_if(%boolexpr: i1) {
 // CHECK:         br label %[[codeRepl:[^,]+]]
 // CHECK:       [[codeRepl]]:
 // CHECK:         %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK:         %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0, i64 0, ptr @omp_task_if..omp_par.wrapper)
+// CHECK:         %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @omp_task_if..omp_par.wrapper)
 // CHECK:         br i1 %[[boolexpr]], label %[[true_label:[^,]+]], label %[[false_label:[^,]+]]
 // CHECK:       [[true_label]]:
 // CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
        
    
    
More information about the llvm-commits mailing list