[Mlir-commits] [mlir] 9b57b16 - [OMPIRBuilder] Fix shared clause for task construct
Prabhdeep Singh Soni
llvmlistbot at llvm.org
Fri Sep 15 09:20:04 PDT 2023
Author: Prabhdeep Singh Soni
Date: 2023-09-15T12:19:47-04:00
New Revision: 9b57b167bb4d849b6803e28f638b970f493511f9
URL: https://github.com/llvm/llvm-project/commit/9b57b167bb4d849b6803e28f638b970f493511f9
DIFF: https://github.com/llvm/llvm-project/commit/9b57b167bb4d849b6803e28f638b970f493511f9.diff
LOG: [OMPIRBuilder] Fix shared clause for task construct
This patch fixes the shared clause for the task construct with multiple
shared variables. The shareds field in kmp_task_t is not an inline
array in the struct; rather, it is a pointer to an array. Treating it as an
inline array caused the outlined function body of the task to dereference an
invalid pointer, which resulted in a segmentation fault when the task was
run by the runtime.
Reviewed By: kiranchandramohan, jdoerfert
Differential Revision: https://reviews.llvm.org/D158462
Added:
Modified:
llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
mlir/test/Target/LLVMIR/openmp-llvm.mlir
Removed:
################################################################################
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index c4218326280b2ba..176b883fe68f7ad 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -95,6 +95,7 @@ __OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, false, Int32, Int32, VoidP
Int64, Int64, Int32Arr3Ty, Int32Arr3Ty, Int32)
__OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, false, Int8Ptr)
__OMP_STRUCT_TYPE(DependInfo, kmp_dep_info, false, SizeTy, SizeTy, Int8)
+__OMP_STRUCT_TYPE(Task, kmp_task_ompbuilder_t, false, VoidPtr, VoidPtr, Int32, VoidPtr, VoidPtr)
__OMP_STRUCT_TYPE(ConfigurationEnvironment, ConfigurationEnvironmentTy, false,
Int8, Int8, Int8)
__OMP_STRUCT_TYPE(DynamicEnvironment, DynamicEnvironmentTy, false, Int16)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 8e9200aa1821985..1ace7d5b97ffc96 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1555,9 +1555,9 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
"there must be a single user for the outlined function");
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
- // HasTaskData is true if any variables are captured in the outlined region,
+ // HasShareds is true if any variables are captured in the outlined region,
// false otherwise.
- bool HasTaskData = StaleCI->arg_size() > 0;
+ bool HasShareds = StaleCI->arg_size() > 0;
Builder.SetInsertPoint(StaleCI);
// Gather the arguments for emitting the runtime call for
@@ -1585,8 +1585,15 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
// Argument - `sizeof_kmp_task_t` (TaskSize)
// Tasksize refers to the size in bytes of kmp_task_t data structure
// including private vars accessed in task.
- Value *TaskSize = Builder.getInt64(0);
- if (HasTaskData) {
+ // TODO: add kmp_task_t_with_privates (privates)
+ Value *TaskSize = Builder.getInt64(
+ divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
+
+ // Argument - `sizeof_shareds` (SharedsSize)
+ // SharedsSize refers to the shareds array size in the kmp_task_t data
+ // structure.
+ Value *SharedsSize = Builder.getInt64(0);
+ if (HasShareds) {
AllocaInst *ArgStructAlloca =
dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
assert(ArgStructAlloca &&
@@ -1596,19 +1603,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
assert(ArgStructType && "Unable to find struct type corresponding to "
"arguments for extracted function");
- TaskSize =
+ SharedsSize =
Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
}
- // TODO: Argument - sizeof_shareds
-
// Argument - task_entry (the wrapper function)
- // If the outlined function has some captured variables (i.e. HasTaskData is
+ // If the outlined function has some captured variables (i.e. HasShareds is
// true), then the wrapper function will have an additional argument (the
// struct containing captured variables). Otherwise, no such argument will
// be present.
SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
- if (HasTaskData)
+ if (HasShareds)
WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
(Twine(OutlinedFn.getName()) + ".wrapper").str(),
@@ -1617,19 +1622,19 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
// Emit the @__kmpc_omp_task_alloc runtime call
// The runtime call returns a pointer to an area where the task captured
- // variables must be copied before the task is run (NewTaskData)
- CallInst *NewTaskData = Builder.CreateCall(
- TaskAllocFn,
- {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
- /*sizeof_task=*/TaskSize, /*sizeof_shared=*/Builder.getInt64(0),
- /*task_func=*/WrapperFunc});
+ // variables must be copied before the task is run (TaskData)
+ CallInst *TaskData = Builder.CreateCall(
+ TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
+ /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
+ /*task_func=*/WrapperFunc});
// Copy the arguments for outlined function
- if (HasTaskData) {
- Value *TaskData = StaleCI->getArgOperand(0);
+ if (HasShareds) {
+ Value *Shareds = StaleCI->getArgOperand(0);
Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
- Builder.CreateMemCpy(NewTaskData, Alignment, TaskData, Alignment,
- TaskSize);
+ Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
+ Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
+ SharedsSize);
}
Value *DepArrayPtr = nullptr;
@@ -1705,12 +1710,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
Function *TaskCompleteFn =
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
- Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, NewTaskData});
- if (HasTaskData)
- Builder.CreateCall(WrapperFunc, {ThreadID, NewTaskData});
+ Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
+ if (HasShareds)
+ Builder.CreateCall(WrapperFunc, {ThreadID, TaskData});
else
Builder.CreateCall(WrapperFunc, {ThreadID});
- Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
+ Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
Builder.SetInsertPoint(ThenTI);
}
@@ -1719,14 +1724,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
Builder.CreateCall(
TaskFn,
- {Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()),
+ {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
DepArrayPtr, ConstantInt::get(Builder.getInt32Ty(), 0),
ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))});
} else {
// Emit the @__kmpc_omp_task runtime call to spawn the task
Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
- Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+ Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
}
StaleCI->eraseFromParent();
@@ -1735,10 +1740,13 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
BasicBlock *WrapperEntryBB =
BasicBlock::Create(M.getContext(), "", WrapperFunc);
Builder.SetInsertPoint(WrapperEntryBB);
- if (HasTaskData)
- Builder.CreateCall(&OutlinedFn, {WrapperFunc->getArg(1)});
- else
+ if (HasShareds) {
+ llvm::Value *Shareds =
+ Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1));
+ Builder.CreateCall(&OutlinedFn, {Shareds});
+ } else {
Builder.CreateCall(&OutlinedFn);
+ }
Builder.CreateRet(Builder.getInt32(0));
};
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index bc1687cae25f1d9..2026824416f3e3c 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5397,19 +5397,29 @@ TEST_F(OpenMPIRBuilderTest, CreateTask) {
ConstantInt *DataSize =
dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
ASSERT_NE(DataSize, nullptr);
- EXPECT_EQ(DataSize->getSExtValue(), 24); // 64-bit pointer + 128-bit integer
+ EXPECT_EQ(DataSize->getSExtValue(), 40);
- // TODO: Verify size of shared clause variables
+ ConstantInt *SharedsSize =
+ dyn_cast<ConstantInt>(TaskAllocCall->getOperand(4));
+ EXPECT_EQ(SharedsSize->getSExtValue(),
+ 24); // 64-bit pointer + 128-bit integer
// Verify Wrapper function
Function *WrapperFunc =
dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
ASSERT_NE(WrapperFunc, nullptr);
+
+ LoadInst *SharedsLoad = dyn_cast<LoadInst>(WrapperFunc->begin()->begin());
+ ASSERT_NE(SharedsLoad, nullptr);
+ EXPECT_EQ(SharedsLoad->getPointerOperand(), WrapperFunc->getArg(1));
+
EXPECT_FALSE(WrapperFunc->isDeclaration());
- CallInst *OutlinedFnCall = dyn_cast<CallInst>(WrapperFunc->begin()->begin());
+ CallInst *OutlinedFnCall =
+ dyn_cast<CallInst>(++WrapperFunc->begin()->begin());
ASSERT_NE(OutlinedFnCall, nullptr);
EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
- EXPECT_EQ(OutlinedFnCall->getArgOperand(0), WrapperFunc->getArg(1));
+ EXPECT_EQ(OutlinedFnCall->getArgOperand(0),
+ WrapperFunc->getArg(1)->uses().begin()->getUser());
// Verify the presence of `trunc` and `icmp` instructions in Outlined function
Function *OutlinedFn = OutlinedFnCall->getCalledFunction();
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 1dff33cdc22ce12..28b0113a19d61b8 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2208,7 +2208,7 @@ llvm.mlir.global internal @_QFsubEx() : i32
llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
- // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0,
+ // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
// CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]])
// CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
omp.task {
@@ -2258,7 +2258,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
- // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0,
+ // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
// CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]])
// CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}})
omp.task depend(taskdependin -> %zaddr : !llvm.ptr<i32>) {
@@ -2303,9 +2303,10 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
llvm.store %diff, %zaddr : !llvm.ptr<i32>
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
- // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 16, i64 0,
+ // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 16,
// CHECK-SAME: ptr @[[wrapper_fn:.+]])
- // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[task_data]], ptr {{.+}}, i64 16, i1 false)
+ // CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]]
+ // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[shareds]], ptr {{.+}}, i64 16, i1 false)
// CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
omp.task {
%z = llvm.add %x, %y : i32
@@ -2334,7 +2335,8 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}, ptr %[[task_data:.+]]) {
-// CHECK: call void @[[outlined_fn]](ptr %[[task_data]])
+// CHECK: %[[shareds:.+]] = load ptr, ptr %1, align 8
+// CHECK: call void @[[outlined_fn]](ptr %[[shareds]])
// CHECK: ret i32 0
// CHECK: }
@@ -2430,7 +2432,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: br label %[[codeRepl:[^,]+]]
// CHECK: [[codeRepl]]:
// CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 0, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
+// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]])
// CHECK: br label %[[task_exit:[^,]+]]
// CHECK: [[task_exit]]:
@@ -2443,8 +2445,9 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2
// CHECK: store ptr %[[zaddr]], ptr %[[gep3]], align 8
// CHECK: %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 16, i64 0, ptr @omp_taskgroup_task..omp_par.1.wrapper)
-// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 8 %[[t2_alloc]], ptr align 8 %[[structArg]], i64 16, i1 false)
+// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @omp_taskgroup_task..omp_par.1.wrapper)
+// CHECK: %[[shareds:.+]] = load ptr, ptr %[[t2_alloc]]
+// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[shareds]], ptr align 1 %[[structArg]], i64 16, i1 false)
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], ptr %[[t2_alloc]])
// CHECK: br label %[[task_exit3:[^,]+]]
// CHECK: [[task_exit3]]:
@@ -2614,7 +2617,7 @@ llvm.func @omp_task_final(%boolexpr: i1) {
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
// CHECK: %[[final_flag:.+]] = select i1 %[[boolexpr]], i32 2, i32 0
// CHECK: %[[task_flags:.+]] = or i32 %[[final_flag]], 1
-// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 0, i64 0, ptr @omp_task_final..omp_par.wrapper)
+// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @omp_task_final..omp_par.wrapper)
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
// CHECK: br label %[[task_exit:[^,]+]]
// CHECK: [[task_exit]]:
@@ -2645,7 +2648,7 @@ llvm.func @omp_task_if(%boolexpr: i1) {
// CHECK: br label %[[codeRepl:[^,]+]]
// CHECK: [[codeRepl]]:
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0, i64 0, ptr @omp_task_if..omp_par.wrapper)
+// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @omp_task_if..omp_par.wrapper)
// CHECK: br i1 %[[boolexpr]], label %[[true_label:[^,]+]], label %[[false_label:[^,]+]]
// CHECK: [[true_label]]:
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
More information about the Mlir-commits
mailing list