[clang] 6149589 - [OMPIRBuilder] Support depend clause for task
Prabhdeep Singh Soni via cfe-commits
cfe-commits at lists.llvm.org
Wed Oct 19 10:11:56 PDT 2022
Author: Prabhdeep Singh Soni
Date: 2022-10-19T13:11:43-04:00
New Revision: 614958912784a13737720de39b2da40fe6f26e75
URL: https://github.com/llvm/llvm-project/commit/614958912784a13737720de39b2da40fe6f26e75
DIFF: https://github.com/llvm/llvm-project/commit/614958912784a13737720de39b2da40fe6f26e75.diff
LOG: [OMPIRBuilder] Support depend clause for task
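This patch adds support for the `depend` clause on the `task` construct: when
dependencies are present, the task is spawned with `__kmpc_omp_task_with_deps`
instead of `__kmpc_omp_task`. It also moves the `RTLDependenceKindTy` and
`RTLDependInfoFields` enums into OMPConstants.h so Clang and the OMPIRBuilder share them.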
Reviewed By: jdoerfert
Differential Revision: https://reviews.llvm.org/D135695
Added:
Modified:
clang/lib/CodeGen/CGOpenMPRuntime.cpp
llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Removed:
################################################################################
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 84536363d6053..75709740e39ce 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -4377,39 +4377,26 @@ CGOpenMPRuntime::emitTaskInit(CodeGenFunction &CGF, SourceLocation Loc,
return Result;
}
-namespace {
-/// Dependence kind for RTL.
-enum RTLDependenceKindTy {
- DepIn = 0x01,
- DepInOut = 0x3,
- DepMutexInOutSet = 0x4,
- DepInOutSet = 0x8,
- DepOmpAllMem = 0x80,
-};
-/// Fields ids in kmp_depend_info record.
-enum RTLDependInfoFieldsTy { BaseAddr, Len, Flags };
-} // namespace
-
/// Translates internal dependency kind into the runtime kind.
static RTLDependenceKindTy translateDependencyKind(OpenMPDependClauseKind K) {
RTLDependenceKindTy DepKind;
switch (K) {
case OMPC_DEPEND_in:
- DepKind = DepIn;
+ DepKind = RTLDependenceKindTy::DepIn;
break;
// Out and InOut dependencies must use the same code.
case OMPC_DEPEND_out:
case OMPC_DEPEND_inout:
- DepKind = DepInOut;
+ DepKind = RTLDependenceKindTy::DepInOut;
break;
case OMPC_DEPEND_mutexinoutset:
- DepKind = DepMutexInOutSet;
+ DepKind = RTLDependenceKindTy::DepMutexInOutSet;
break;
case OMPC_DEPEND_inoutset:
- DepKind = DepInOutSet;
+ DepKind = RTLDependenceKindTy::DepInOutSet;
break;
case OMPC_DEPEND_outallmemory:
- DepKind = DepOmpAllMem;
+ DepKind = RTLDependenceKindTy::DepOmpAllMem;
break;
case OMPC_DEPEND_source:
case OMPC_DEPEND_sink:
@@ -4457,7 +4444,9 @@ CGOpenMPRuntime::getDepobjElements(CodeGenFunction &CGF, LValue DepobjLVal,
DepObjAddr, KmpDependInfoTy, Base.getBaseInfo(), Base.getTBAAInfo());
// NumDeps = deps[i].base_addr;
LValue BaseAddrLVal = CGF.EmitLValueForField(
- NumDepsBase, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
+ NumDepsBase,
+ *std::next(KmpDependInfoRD->field_begin(),
+ static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
llvm::Value *NumDeps = CGF.EmitLoadOfScalar(BaseAddrLVal, Loc);
return std::make_pair(NumDeps, Base);
}
@@ -4503,18 +4492,24 @@ static void emitDependData(CodeGenFunction &CGF, QualType &KmpDependInfoTy,
}
// deps[i].base_addr = &<Dependencies[i].second>;
LValue BaseAddrLVal = CGF.EmitLValueForField(
- Base, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
+ Base,
+ *std::next(KmpDependInfoRD->field_begin(),
+ static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
CGF.EmitStoreOfScalar(Addr, BaseAddrLVal);
// deps[i].len = sizeof(<Dependencies[i].second>);
LValue LenLVal = CGF.EmitLValueForField(
- Base, *std::next(KmpDependInfoRD->field_begin(), Len));
+ Base, *std::next(KmpDependInfoRD->field_begin(),
+ static_cast<unsigned int>(RTLDependInfoFields::Len)));
CGF.EmitStoreOfScalar(Size, LenLVal);
// deps[i].flags = <Dependencies[i].first>;
RTLDependenceKindTy DepKind = translateDependencyKind(Data.DepKind);
LValue FlagsLVal = CGF.EmitLValueForField(
- Base, *std::next(KmpDependInfoRD->field_begin(), Flags));
- CGF.EmitStoreOfScalar(llvm::ConstantInt::get(LLVMFlagsTy, DepKind),
- FlagsLVal);
+ Base,
+ *std::next(KmpDependInfoRD->field_begin(),
+ static_cast<unsigned int>(RTLDependInfoFields::Flags)));
+ CGF.EmitStoreOfScalar(
+ llvm::ConstantInt::get(LLVMFlagsTy, static_cast<unsigned int>(DepKind)),
+ FlagsLVal);
if (unsigned *P = Pos.dyn_cast<unsigned *>()) {
++(*P);
} else {
@@ -4790,7 +4785,9 @@ Address CGOpenMPRuntime::emitDepobjDependClause(
LValue Base = CGF.MakeAddrLValue(DependenciesArray, KmpDependInfoTy);
// deps[i].base_addr = NumDependencies;
LValue BaseAddrLVal = CGF.EmitLValueForField(
- Base, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
+ Base,
+ *std::next(KmpDependInfoRD->field_begin(),
+ static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
CGF.EmitStoreOfScalar(NumDepsVal, BaseAddrLVal);
llvm::PointerUnion<unsigned *, LValue *> Pos;
unsigned Idx = 1;
@@ -4870,9 +4867,11 @@ void CGOpenMPRuntime::emitUpdateClause(CodeGenFunction &CGF, LValue DepobjLVal,
// deps[i].flags = NewDepKind;
RTLDependenceKindTy DepKind = translateDependencyKind(NewDepKind);
LValue FlagsLVal = CGF.EmitLValueForField(
- Base, *std::next(KmpDependInfoRD->field_begin(), Flags));
- CGF.EmitStoreOfScalar(llvm::ConstantInt::get(LLVMFlagsTy, DepKind),
- FlagsLVal);
+ Base, *std::next(KmpDependInfoRD->field_begin(),
+ static_cast<unsigned int>(RTLDependInfoFields::Flags)));
+ CGF.EmitStoreOfScalar(
+ llvm::ConstantInt::get(LLVMFlagsTy, static_cast<unsigned int>(DepKind)),
+ FlagsLVal);
// Shift the address forward by one element.
Address ElementNext =
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
index 76104f6bc9cfc..b0e9c53e4dabe 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -207,6 +207,19 @@ enum class OMPInteropType { Unknown, Target, TargetSync };
/// Atomic compare operations. Currently OpenMP only supports ==, >, and <.
enum class OMPAtomicCompareOp : unsigned { EQ, MIN, MAX };
+/// Field indices into the kmp_depend_info record (base_addr, len, flags).
+enum class RTLDependInfoFields { BaseAddr, Len, Flags };
+
+/// Dependence kind stored in the flags field of a kmp_depend_info record.
+enum class RTLDependenceKindTy {
+ DepUnknown = 0x0, ///< Kind of a default-constructed DependData.
+ DepIn = 0x01, ///< depend(in: ...).
+ DepInOut = 0x3, ///< depend(out: ...) and depend(inout: ...) share this code.
+ DepMutexInOutSet = 0x4, ///< depend(mutexinoutset: ...).
+ DepInOutSet = 0x8, ///< depend(inoutset: ...).
+ DepOmpAllMem = 0x80, ///< OMPC_DEPEND_outallmemory (omp_all_memory).
+};
+
} // end namespace omp
} // end namespace llvm
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index c16230facd7b4..c59adc775dc67 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -645,6 +645,17 @@ class OpenMPIRBuilder {
/// \param Loc The location where the taskyield directive was encountered.
void createTaskyield(const LocationDescription &Loc);
+ /// A struct to pack the relevant information for an OpenMP depend clause.
+ struct DependData {
+ omp::RTLDependenceKindTy DepKind = omp::RTLDependenceKindTy::DepUnknown; ///< Stored in the kmp_dep_info flags field.
+ Type *DepValueType = nullptr; ///< Type whose store size is recorded as the dependence length.
+ Value *DepVal = nullptr; ///< Pointer to the dependence variable; its address is recorded.
+ explicit DependData() = default;
+ DependData(omp::RTLDependenceKindTy DepKind, Type *DepValueType,
+ Value *DepVal)
+ : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {}
+ };
+
/// Generator for `#omp task`
///
/// \param Loc The location where the task construct was encountered.
@@ -662,7 +673,8 @@ class OpenMPIRBuilder {
InsertPointTy createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
bool Tied = true, Value *Final = nullptr,
- Value *IfCondition = nullptr);
+ Value *IfCondition = nullptr,
+ ArrayRef<DependData *> Dependencies = {});
/// Generator for the taskgroup construct
///
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index 6e7c0d3386849..71abc8822730a 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -92,6 +92,7 @@ __OMP_STRUCT_TYPE(OffloadEntry, __tgt_offload_entry, Int8Ptr, Int8Ptr, SizeTy,
__OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, Int32, Int32, VoidPtrPtr,
VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr, Int64)
__OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, Int8Ptr)
+__OMP_STRUCT_TYPE(DependInfo, kmp_dep_info, SizeTy, SizeTy, Int8)
#undef __OMP_STRUCT_TYPE
#undef OMP_STRUCT_TYPE
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index adc531620ec93..91bd2fe0726c5 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1290,7 +1290,8 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
- bool Tied, Value *Final, Value *IfCondition) {
+ bool Tied, Value *Final, Value *IfCondition,
+ ArrayRef<DependData *> Dependencies) {
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -1322,8 +1323,8 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
OI.EntryBB = TaskAllocaBB;
OI.OuterAllocaBB = AllocaIP.getBlock();
OI.ExitBB = TaskExitBB;
- OI.PostOutlineCB = [this, Ident, Tied, Final,
- IfCondition](Function &OutlinedFn) {
+ OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
+ Dependencies](Function &OutlinedFn) {
// The input IR here looks like the following-
// ```
// func @current_fn() {
@@ -1433,6 +1434,49 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
TaskSize);
}
+ Value *DepArrayPtr = nullptr;
+ if (Dependencies.size()) {
+ InsertPointTy OldIP = Builder.saveIP();
+ Builder.SetInsertPoint(
+ &OldIP.getBlock()->getParent()->getEntryBlock().back());
+
+ Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
+ Value *DepArray =
+ Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
+
+ unsigned P = 0;
+ for (DependData *Dep : Dependencies) {
+ Value *Base =
+ Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
+ // Store the pointer to the variable
+ Value *Addr = Builder.CreateStructGEP(
+ DependInfo, Base,
+ static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
+ Value *DepValPtr =
+ Builder.CreatePtrToInt(Dep->DepVal, Builder.getInt64Ty());
+ Builder.CreateStore(DepValPtr, Addr);
+ // Store the size of the variable
+ Value *Size = Builder.CreateStructGEP(
+ DependInfo, Base,
+ static_cast<unsigned int>(RTLDependInfoFields::Len));
+ Builder.CreateStore(Builder.getInt64(M.getDataLayout().getTypeStoreSize(
+ Dep->DepValueType)),
+ Size);
+ // Store the dependency kind
+ Value *Flags = Builder.CreateStructGEP(
+ DependInfo, Base,
+ static_cast<unsigned int>(RTLDependInfoFields::Flags));
+ Builder.CreateStore(
+ ConstantInt::get(Builder.getInt8Ty(),
+ static_cast<unsigned int>(Dep->DepKind)),
+ Flags);
+ ++P;
+ }
+
+ DepArrayPtr = Builder.CreateBitCast(DepArray, Builder.getInt8PtrTy());
+ Builder.restoreIP(OldIP);
+ }
+
// In the presence of the `if` clause, the following IR is generated:
// ...
// %data = call @__kmpc_omp_task_alloc(...)
@@ -1471,9 +1515,21 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
Builder.SetInsertPoint(ThenTI);
}
- // Emit the @__kmpc_omp_task runtime call to spawn the task
- Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
- Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+
+ if (Dependencies.size()) {
+ Function *TaskFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
+ Builder.CreateCall(
+ TaskFn,
+ {Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()),
+ DepArrayPtr, ConstantInt::get(Builder.getInt32Ty(), 0),
+ ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))});
+
+ } else {
+ // Emit the @__kmpc_omp_task runtime call to spawn the task
+ Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
+ Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+ }
StaleCI->eraseFromParent();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index af96ac2c0dd2b..7ae13a51d3a06 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5092,6 +5092,81 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}
+TEST_F(OpenMPIRBuilderTest, CreateTaskDepend) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {}; // Empty task body: only the depend lowering is checked.
+ BasicBlock *AllocaBB = Builder.GetInsertBlock();
+ BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+ OpenMPIRBuilder::LocationDescription Loc(
+ InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+ AllocaInst *InDep = Builder.CreateAlloca(Type::getInt32Ty(M->getContext()));
+ OpenMPIRBuilder::DependData DDIn(RTLDependenceKindTy::DepIn,
+ Type::getInt32Ty(M->getContext()), InDep); // A single `in` dependence on the i32 InDep.
+ SmallVector<OpenMPIRBuilder::DependData *, 4> DDS;
+ DDS.push_back(&DDIn);
+ Builder.restoreIP(OMPBuilder.createTask(
+ Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB,
+ /*Tied=*/false, /*Final*/ nullptr, /*IfCondition*/ nullptr, DDS));
+ OMPBuilder.finalize(); // DDS/DDIn must still be alive here: createTask's PostOutlineCB captured the ArrayRef, not copies.
+ Builder.CreateRetVoid();
+
+ // Check the `NumDeps` argument (arg 3); TaskAllocCall is the __kmpc_omp_task_with_deps call.
+ CallInst *TaskAllocCall = dyn_cast<CallInst>(
+ OMPBuilder
+ .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps)
+ ->user_back());
+ ASSERT_NE(TaskAllocCall, nullptr);
+ ConstantInt *NumDeps = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
+ ASSERT_NE(NumDeps, nullptr);
+ EXPECT_EQ(NumDeps->getZExtValue(), 1U);
+
+ // Check the `DepInfo` array argument (arg 4): a bitcast of the .dep.arr.addr alloca.
+ BitCastInst *DepArrayPtr =
+ dyn_cast<BitCastInst>(TaskAllocCall->getOperand(4));
+ ASSERT_NE(DepArrayPtr, nullptr);
+ AllocaInst *DepArray = dyn_cast<AllocaInst>(DepArrayPtr->getOperand(0));
+ ASSERT_NE(DepArray, nullptr);
+ Value::user_iterator DepArrayI = DepArray->user_begin();
+ EXPECT_EQ(*DepArrayI, DepArrayPtr); // Assumes use-list order: the bitcast, created after the element GEP, is visited first.
+ ++DepArrayI;
+ Value::user_iterator DepInfoI = DepArrayI->user_begin(); // Users of the element GEP; the checks below assume Flags, Len, BaseAddr order.
+ // Flags field: must hold RTLDependenceKindTy::DepIn.
+ Value *Flag = findStoredValue<GetElementPtrInst>(*DepInfoI);
+ ASSERT_NE(Flag, nullptr);
+ ConstantInt *FlagInt = dyn_cast<ConstantInt>(Flag);
+ ASSERT_NE(FlagInt, nullptr);
+ EXPECT_EQ(FlagInt->getZExtValue(),
+ static_cast<unsigned int>(RTLDependenceKindTy::DepIn));
+ ++DepInfoI;
+ // Len field: store size of the i32 dependence variable, i.e. 4 bytes.
+ Value *Size = findStoredValue<GetElementPtrInst>(*DepInfoI);
+ ASSERT_NE(Size, nullptr);
+ ConstantInt *SizeInt = dyn_cast<ConstantInt>(Size);
+ ASSERT_NE(SizeInt, nullptr);
+ EXPECT_EQ(SizeInt->getZExtValue(), 4U);
+ ++DepInfoI;
+ // BaseAddr field: ptrtoint of the InDep alloca.
+ Value *AddrStored = findStoredValue<GetElementPtrInst>(*DepInfoI);
+ ASSERT_NE(AddrStored, nullptr);
+ PtrToIntInst *AddrInt = dyn_cast<PtrToIntInst>(AddrStored);
+ ASSERT_NE(AddrInt, nullptr);
+ Value *Addr = AddrInt->getPointerOperand();
+ EXPECT_EQ(Addr, InDep);
+ // Remaining args: the noalias dependence count (arg 5) is 0 and the noalias list (arg 6) is null.
+ ConstantInt *NumDepsNoAlias =
+ dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(5));
+ ASSERT_NE(NumDepsNoAlias, nullptr);
+ EXPECT_EQ(NumDepsNoAlias->getZExtValue(), 0U);
+ EXPECT_EQ(TaskAllocCall->getOperand(6),
+ ConstantPointerNull::get(Type::getInt8PtrTy(M->getContext())));
+
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
More information about the cfe-commits mailing list