[llvm-branch-commits] [llvm] [mlir] [OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode (PR #116051)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Nov 27 04:26:12 PST 2024
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116051
>From f120456cd3200ff82cca63570272d57ba909fe87 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 27 Nov 2024 11:33:01 +0000
Subject: [PATCH 1/2] [OMPIRBuilder] Support runtime number of teams and
threads, and SPMD mode
This patch introduces a `TargetKernelRuntimeAttrs` structure to hold
host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values
passed to the runtime kernel offloading call.
Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to
influence target device code generation.
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 26 +-
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 125 +++++++--
.../Frontend/OpenMPIRBuilderTest.cpp | 256 +++++++++++++++++-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 10 +-
4 files changed, 383 insertions(+), 34 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index da450ef5adbc14..a85f41e586c514 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
int32_t MinThreads = 1;
};
+ /// Container to pass LLVM IR runtime values or constants related to the
+ /// number of teams and threads with which the kernel must be launched, as
+ /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
+ /// must be defined in the host prior to the call to the kernel launch OpenMP
+ /// RTL function.
+ struct TargetKernelRuntimeAttrs {
+ SmallVector<Value *, 3> MaxTeams = {nullptr};
+ Value *MinTeams = nullptr;
+ SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+ SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+ /// 'parallel' construct 'num_threads' clause value, if present and it is a
+ /// target SPMD kernel.
+ Value *MaxThreads = nullptr;
+
+ /// Total number of iterations of the target SPMD kernel or null if it is a
+ /// generic kernel.
+ Value *LoopTripCount = nullptr;
+ };
+
/// Data structure that contains the needed information to construct the
/// kernel args vector.
struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
///
/// \param Loc where the target data construct was encountered.
/// \param IsOffloadEntry whether it is an offload entry.
+ /// \param IsSPMD whether it is a target SPMD kernel.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
/// \param DefaultAttrs Structure containing the default numbers of threads
/// and teams to launch the kernel with.
+ /// \param RuntimeAttrs Structure containing the runtime numbers of threads
+ /// and teams to launch the kernel with.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
// dependency information as passed in the depend clause
// \param HasNowait Whether the target construct has a `nowait` clause or not.
InsertPointOrErrorTy createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry,
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 302d363965c940..09f794ccf734b3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
}
+static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
+ Module &M) {
+ if (List.empty())
+ return;
+
+ Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
+
+ // Convert List to what ConstantArray needs.
+ SmallVector<Constant *, 8> UsedArray;
+ UsedArray.reserve(List.size());
+ for (auto Item : List)
+ UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(&*Item), PtrTy));
+
+ ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
+ auto *GV =
+ new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
+ llvm::ConstantArray::get(ArrTy, UsedArray), Name);
+
+ GV->setSection("llvm.metadata");
+}
+
+static void
+emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ StringRef FunctionName, OMPTgtExecModeFlags Mode,
+ std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
+ auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
+ auto *GVMode = new llvm::GlobalVariable(
+ OMPBuilder.M, Int8Ty, /*isConstant=*/true,
+ llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
+ Twine(FunctionName, "_exec_mode"));
+ GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
+ LLVMCompilerUsed.emplace_back(GVMode);
+}
+
static Expected<Function *> createOutlinedFunction(
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,15 @@ static Expected<Function *> createOutlinedFunction(
auto Func =
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
+ if (OMPBuilder.Config.isTargetDevice()) {
+ std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
+ emitExecutionMode(OMPBuilder, Builder, FuncName,
+ IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
+ : OMP_TGT_EXEC_MODE_GENERIC,
+ LLVMCompilerUsed);
+ emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+ }
+
// Save insert point.
IRBuilder<>::InsertPointGuard IPG(Builder);
// If there's a DISubprogram associated with current function, then
@@ -6798,7 +6842,7 @@ static Expected<Function *> createOutlinedFunction(
// Insert target init call in the device compilation pass.
if (OMPBuilder.Config.isTargetDevice())
Builder.restoreIP(
- OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
+ OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
@@ -6995,7 +7039,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
static Error emitTargetOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
- TargetRegionEntryInfo &EntryInfo,
+ bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
Function *&OutlinedFn, Constant *&OutlinedFnID,
SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7048,7 @@ static Error emitTargetOutlinedFunction(
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
[&](StringRef EntryFnName) {
- return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+ return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
EntryFnName, Inputs, CBFunc,
ArgAccessorFuncCB);
};
@@ -7304,6 +7348,7 @@ static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7430,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
/*ForEndCall=*/false);
SmallVector<Value *, 3> NumTeamsC;
+ for (auto [DefaultVal, RuntimeVal] :
+ zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+ NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+ if (Clause)
+ Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+ /*isSigned=*/false);
+ return Clause;
+ };
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+ if (Clause)
+ Result = Result
+ ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+ Result, Clause)
+ : Clause;
+ };
+
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+ // the NUM_THREADS clause is overridden by THREAD_LIMIT.
SmallVector<Value *, 3> NumThreadsC;
- for (auto V : DefaultAttrs.MaxTeams)
- NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : DefaultAttrs.MaxThreads)
- NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+ ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+ : nullptr;
+
+ for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
+ RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+ Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+ CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+ CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+ NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+ }
unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
@@ -7398,14 +7475,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
- // TODO: Use correct NumIterations
- Value *NumIterations = Builder.getInt64(0);
+
+ Value *TripCount = RuntimeAttrs.LoopTripCount
+ ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+ Builder.getInt64Ty(),
+ /*isSigned=*/false)
+ : Builder.getInt64(0);
+
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
- KArgs = OpenMPIRBuilder::TargetKernelArgs(
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+ NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
@@ -7427,13 +7509,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
- InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+ InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+ TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
+ assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
+ "trip count not expected if IsSPMD=false");
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -7446,16 +7532,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB
if (Error Err = emitTargetOutlinedFunction(
- *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
+ *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+ OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
return Err;
// If we are not on the target device, then we need to generate code
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+ emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+ HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b0688d6215e42d..a8c786b5886afe 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6123,7 +6123,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
OMPBuilder.setConfig(Config);
F->setName("func");
IRBuilder<> Builder(BB);
- auto Int32Ty = Builder.getInt32Ty();
+ auto *Int32Ty = Builder.getInt32Ty();
AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6183,11 +6183,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
- /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
- Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+ RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+ RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
OMPBuilder.finalize();
@@ -6207,6 +6211,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
StringRef FunctionName = KernelLaunchFunc->getName();
EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+ // Check num_teams and num_threads in call arguments
+ EXPECT_TRUE(Call->arg_size() >= 4);
+ Value *NumTeamsArg = Call->getArgOperand(2);
+ EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+ EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+ Value *NumThreadsArg = Call->getArgOperand(3);
+ EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+ EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+ // Check num_teams and num_threads kernel arguments (uses number 3 and 2
+ // from the end, counting the call to __tgt_target_kernel as the first use)
+ Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+ EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+ Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+ Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+ Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+ auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+ EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+ EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+ Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+ Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+ Value *NumThreadsStoreArg =
+ cast<StoreInst>(NumThreadsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+ auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+ EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+ EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
// Check the fallback call
BasicBlock *FallbackBlock = Branch->getSuccessor(0);
Iter = FallbackBlock->rbegin();
@@ -6297,9 +6338,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
- CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -6378,6 +6421,197 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
EXPECT_EQ(ExitBlock->getName(), "worker.exit");
EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
+
+ // Check global exec_mode.
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+ EXPECT_NE(Used, nullptr);
+ Constant *UsedInit = Used->getInitializer();
+ EXPECT_NE(UsedInit, nullptr);
+ EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+ auto *UsedInitData = cast<ConstantArray>(UsedInit);
+ EXPECT_EQ(1U, UsedInitData->getNumOperands());
+ Constant *ExecMode = UsedInitData->getOperand(0);
+ EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+ Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+ EXPECT_NE(ExecModeValue, nullptr);
+ EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+ EXPECT_EQ(OMP_TGT_EXEC_MODE_GENERIC,
+ cast<ConstantInt>(ExecModeValue)->getZExtValue());
+}
+
+TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ OpenMPIRBuilderConfig Config(/*IsTargetDevice=*/false, /*IsGPU=*/false,
+ /*OpenMPOffloadMandatory=*/false,
+ /*HasRequiresReverseOffload=*/false,
+ /*HasRequiresUnifiedAddress=*/false,
+ /*HasRequiresUnifiedSharedMemory=*/false,
+ /*HasRequiresDynamicAllocators=*/false);
+ OMPBuilder.setConfig(Config);
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+
+ auto BodyGenCB = [&](InsertPointTy,
+ InsertPointTy CodeGenIP) -> InsertPointTy {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
+
+ auto SimpleArgAccessorCB =
+ [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
+ llvm::OpenMPIRBuilder::InsertPointTy,
+ llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
+
+ llvm::SmallVector<llvm::Value *> Inputs;
+ llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
+ auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
+ -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+
+ TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
+ OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ RuntimeAttrs.LoopTripCount = Builder.getInt64(1000);
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ assert(AfterIP && "unexpected error");
+ Builder.restoreIP(*AfterIP);
+ OMPBuilder.finalize();
+ Builder.CreateRetVoid();
+
+ // Check the kernel launch sequence
+ auto Iter = F->getEntryBlock().rbegin();
+ EXPECT_TRUE(isa<BranchInst>(&*(Iter)));
+ BranchInst *Branch = dyn_cast<BranchInst>(&*(Iter));
+ EXPECT_TRUE(isa<CmpInst>(&*(++Iter)));
+ EXPECT_TRUE(isa<CallInst>(&*(++Iter)));
+ CallInst *Call = dyn_cast<CallInst>(&*(Iter));
+
+ // Check that the kernel launch function is called
+ Function *KernelLaunchFunc = Call->getCalledFunction();
+ EXPECT_NE(KernelLaunchFunc, nullptr);
+ StringRef FunctionName = KernelLaunchFunc->getName();
+ EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+
+ // Check the trip count kernel argument (use number 5 starting from the end
+ // and counting the call to __tgt_target_kernel as the first use)
+ Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+ EXPECT_TRUE(KernelArgs->getNumUses() >= 6);
+ Value *TripCountGetElemPtr = *std::next(KernelArgs->user_begin(), 5);
+ EXPECT_TRUE(isa<GetElementPtrInst>(TripCountGetElemPtr));
+ Value *TripCountStore = TripCountGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(TripCountStore));
+ Value *TripCountStoreArg = cast<StoreInst>(TripCountStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantInt>(TripCountStoreArg));
+ EXPECT_EQ(1000U, cast<ConstantInt>(TripCountStoreArg)->getZExtValue());
+
+ // Check the fallback call
+ BasicBlock *FallbackBlock = Branch->getSuccessor(0);
+ Iter = FallbackBlock->rbegin();
+ CallInst *FCall = dyn_cast<CallInst>(&*(++Iter));
+ // 'F' has a dummy DISubprogram which causes OutlinedFunc to also
+ // have a DISubprogram. In this case, the call to OutlinedFunc needs
+ // to have a debug loc, otherwise verifier will complain.
+ FCall->setDebugLoc(DL);
+ EXPECT_NE(FCall, nullptr);
+
+ // Check that the outlined function exists with the expected prefix
+ Function *OutlinedFunc = FCall->getCalledFunction();
+ EXPECT_NE(OutlinedFunc, nullptr);
+ StringRef FunctionName2 = OutlinedFunc->getName();
+ EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
+
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
+TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.setConfig(
+ OpenMPIRBuilderConfig(/*IsTargetDevice=*/true, /*IsGPU=*/false,
+ /*OpenMPOffloadMandatory=*/false,
+ /*HasRequiresReverseOffload=*/false,
+ /*HasRequiresUnifiedAddress=*/false,
+ /*HasRequiresUnifiedSharedMemory=*/false,
+ /*HasRequiresDynamicAllocators=*/false));
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+ Function *OutlinedFn = nullptr;
+ llvm::SmallVector<llvm::Value *> CapturedArgs;
+
+ auto SimpleArgAccessorCB =
+ [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
+ llvm::OpenMPIRBuilder::InsertPointTy,
+ llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
+
+ llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
+ auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
+ -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+
+ auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP)
+ -> OpenMPIRBuilder::InsertPointTy {
+ Builder.restoreIP(CodeGenIP);
+ OutlinedFn = CodeGenIP.getBlock()->getParent();
+ return Builder.saveIP();
+ };
+
+ IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
+ F->getEntryBlock().getFirstInsertionPt());
+ TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
+ /*Line=*/3, /*Count=*/0);
+
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB);
+ assert(AfterIP && "unexpected error");
+ Builder.restoreIP(*AfterIP);
+
+ Builder.CreateRetVoid();
+ OMPBuilder.finalize();
+
+ // Check outlined function
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+ EXPECT_NE(OutlinedFn, nullptr);
+ EXPECT_NE(F, OutlinedFn);
+
+ EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
+ // Account for the "implicit" first argument.
+ EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
+ EXPECT_EQ(OutlinedFn->arg_size(), 1U);
+
+ // Check global exec_mode.
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+ EXPECT_NE(Used, nullptr);
+ Constant *UsedInit = Used->getInitializer();
+ EXPECT_NE(UsedInit, nullptr);
+ EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+ auto *UsedInitData = cast<ConstantArray>(UsedInit);
+ EXPECT_EQ(1U, UsedInitData->getNumOperands());
+ Constant *ExecMode = UsedInitData->getOperand(0);
+ EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+ Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+ EXPECT_NE(ExecModeValue, nullptr);
+ EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+ EXPECT_EQ(OMP_TGT_EXEC_MODE_SPMD,
+ cast<ConstantInt>(ExecModeValue)->getZExtValue());
}
TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
@@ -6448,9 +6682,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
- CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d3c3839accb7e7..9bdf3e11496f3a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3936,9 +3936,11 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
allocaIP, codeGenIP);
};
- // TODO: Populate default attributes based on the construct and clauses.
+ // TODO: Populate default and runtime attributes based on the construct and
+ // clauses.
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
llvm::SmallVector<llvm::Value *, 4> kernelInput;
for (size_t i = 0; i < mapVars.size(); ++i) {
@@ -3957,9 +3959,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
- ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
- defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
- targetOp.getNowait());
+ ompLoc, isOffloadEntry, /*IsSPMD=*/false, allocaIP, builder.saveIP(),
+ entryInfo, defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB,
+ bodyCB, argAccessorCB, dds, targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
>From e2b3ac479a5333988a2e8dcbaff26d7a87ea623c Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 27 Nov 2024 12:08:46 +0000
Subject: [PATCH 2/2] Address review comments
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 10 +-
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 104 +++++++-----------
.../Frontend/OpenMPIRBuilderTest.cpp | 46 ++++----
3 files changed, 68 insertions(+), 92 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index a85f41e586c514..f0ef58ed3b19ba 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1387,9 +1387,6 @@ class OpenMPIRBuilder {
/// Supporting functions for Reductions CodeGen.
private:
- /// Emit the llvm.used metadata.
- void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List);
-
/// Get the id of the current thread on the GPU.
Value *getGPUThreadID();
@@ -2011,6 +2008,13 @@ class OpenMPIRBuilder {
/// Value.
GlobalValue *createGlobalFlag(unsigned Value, StringRef Name);
+ /// Emit the \p Name metadata global (e.g. llvm.used or llvm.compiler.used).
+ void emitUsed(StringRef Name, ArrayRef<llvm::WeakTrackingVH> List);
+
+ /// Emit the kernel execution mode.
+ GlobalVariable *emitKernelExecutionMode(StringRef KernelName,
+ omp::OMPTgtExecModeFlags Mode);
+
/// Generate control flow and cleanup for cancellation.
///
/// \param CancelFlag Flag indicating if the cancellation is performed.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 09f794ccf734b3..73f221c07af746 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -831,6 +831,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
return GV;
}
+void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
+ if (List.empty())
+ return;
+
+ // Convert List to what ConstantArray needs.
+ SmallVector<Constant *, 8> UsedArray;
+ UsedArray.resize(List.size());
+ for (unsigned I = 0, E = List.size(); I != E; ++I)
+ UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(&*List[I]), Builder.getPtrTy());
+
+ if (UsedArray.empty())
+ return;
+ ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
+
+ auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
+ ConstantArray::get(ATy, UsedArray), Name);
+
+ GV->setSection("llvm.metadata");
+}
+
+GlobalVariable *
+OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
+ OMPTgtExecModeFlags Mode) {
+ auto *Int8Ty = Builder.getInt8Ty();
+ auto *GVMode = new GlobalVariable(
+ M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
+ ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
+ GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
+ return GVMode;
+}
+
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
uint32_t SrcLocStrSize,
IdentFlag LocFlags,
@@ -2242,28 +2274,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
}
-void OpenMPIRBuilder::emitUsed(StringRef Name,
- std::vector<WeakTrackingVH> &List) {
- if (List.empty())
- return;
-
- // Convert List to what ConstantArray needs.
- SmallVector<Constant *, 8> UsedArray;
- UsedArray.resize(List.size());
- for (unsigned I = 0, E = List.size(); I != E; ++I)
- UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
- cast<Constant>(&*List[I]), Builder.getPtrTy());
-
- if (UsedArray.empty())
- return;
- ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
-
- auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
- ConstantArray::get(ATy, UsedArray), Name);
-
- GV->setSection("llvm.metadata");
-}
-
Value *OpenMPIRBuilder::getGPUThreadID() {
return Builder.CreateCall(
getOrCreateRuntimeFunction(M,
@@ -6727,41 +6737,6 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
}
-static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
- Module &M) {
- if (List.empty())
- return;
-
- Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
-
- // Convert List to what ConstantArray needs.
- SmallVector<Constant *, 8> UsedArray;
- UsedArray.reserve(List.size());
- for (auto Item : List)
- UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
- cast<Constant>(&*Item), PtrTy));
-
- ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
- auto *GV =
- new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
- llvm::ConstantArray::get(ArrTy, UsedArray), Name);
-
- GV->setSection("llvm.metadata");
-}
-
-static void
-emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
- StringRef FunctionName, OMPTgtExecModeFlags Mode,
- std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
- auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
- auto *GVMode = new llvm::GlobalVariable(
- OMPBuilder.M, Int8Ty, /*isConstant=*/true,
- llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
- Twine(FunctionName, "_exec_mode"));
- GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
- LLVMCompilerUsed.emplace_back(GVMode);
-}
-
static Expected<Function *> createOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
@@ -6794,12 +6769,9 @@ static Expected<Function *> createOutlinedFunction(
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
if (OMPBuilder.Config.isTargetDevice()) {
- std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
- emitExecutionMode(OMPBuilder, Builder, FuncName,
- IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
- : OMP_TGT_EXEC_MODE_GENERIC,
- LLVMCompilerUsed);
- emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+ Value *ExecMode = OMPBuilder.emitKernelExecutionMode(
+ FuncName, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
+ OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
}
// Save insert point.
@@ -7457,8 +7429,8 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
: nullptr;
- for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
- RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+ for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
+ RuntimeAttrs.TargetThreadLimit)) {
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
Value *NumThreads = InitMaxThreadsClause(TargetVal);
@@ -7518,8 +7490,6 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
- assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
- "trip count not expected if IsSPMD=false");
if (!updateToLocation(Loc))
return InsertPointTy();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index a8c786b5886afe..e4845256633b9c 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6459,18 +6459,19 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
return Builder.saveIP();
};
- auto SimpleArgAccessorCB =
- [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
- llvm::OpenMPIRBuilder::InsertPointTy,
- llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
- Builder.restoreIP(CodeGenIP);
- return Builder.saveIP();
- };
+ auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
+ OpenMPIRBuilder::InsertPointTy,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
- llvm::SmallVector<llvm::Value *> Inputs;
- llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
- auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
- -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+ SmallVector<Value *> Inputs;
+ OpenMPIRBuilder::MapInfosTy CombinedInfos;
+ auto GenMapInfoCB =
+ [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & {
+ return CombinedInfos;
+ };
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
@@ -6547,19 +6548,20 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
Function *OutlinedFn = nullptr;
- llvm::SmallVector<llvm::Value *> CapturedArgs;
+ SmallVector<Value *> CapturedArgs;
- auto SimpleArgAccessorCB =
- [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
- llvm::OpenMPIRBuilder::InsertPointTy,
- llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
- Builder.restoreIP(CodeGenIP);
- return Builder.saveIP();
- };
+ auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
+ OpenMPIRBuilder::InsertPointTy,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
- llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
- auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
- -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+ OpenMPIRBuilder::MapInfosTy CombinedInfos;
+ auto GenMapInfoCB =
+ [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & {
+ return CombinedInfos;
+ };
auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy,
OpenMPIRBuilder::InsertPointTy CodeGenIP)
More information about the llvm-branch-commits mailing list