[llvm-branch-commits] [clang] [llvm] [mlir] [OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode (PR #116051)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Dec 4 06:16:54 PST 2024
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116051
>From 2fbe762b53bb6d6ffdce2b5ae3d6de30584ed93b Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 27 Nov 2024 11:33:01 +0000
Subject: [PATCH 1/3] [OMPIRBuilder] Support runtime number of teams and
threads, and SPMD mode
This patch introduces a `TargetKernelRuntimeAttrs` structure to hold
host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values
passed to the runtime kernel offloading call.
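Additionally, `createTarget` is extended to take an `IsSPMD` flag. This flag
selects the execution mode of the generated target device kernel: it is passed
to the target init call and recorded in a `<kernel>_exec_mode` global.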
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 26 +-
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 125 +++++++--
.../Frontend/OpenMPIRBuilderTest.cpp | 256 +++++++++++++++++-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 10 +-
4 files changed, 383 insertions(+), 34 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index f475e34497105fd..444bc280df9f89b 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
int32_t MinThreads = 1;
};
+ /// Container to pass LLVM IR runtime values or constants related to the
+ /// number of teams and threads with which the kernel must be launched, as
+ /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
+ /// must be defined in the host prior to the call to the kernel launch OpenMP
+ /// RTL function.
+ ///
+ /// A null entry means "not provided at runtime". The number of teams then
+ /// falls back to TargetKernelDefaultAttrs. The number of threads is 0 for a
+ /// dimension where no thread-limiting value is provided.
+ struct TargetKernelRuntimeAttrs {
+ /// Per-dimension 'num_teams' values. Must have the same number of entries
+ /// as TargetKernelDefaultAttrs::MaxTeams. A null entry uses the matching
+ /// default value instead.
+ SmallVector<Value *, 3> MaxTeams = {nullptr};
+ /// NOTE(review): presumably a lower bound on the number of teams; it is not
+ /// read by the kernel launch code visible in this patch -- confirm.
+ Value *MinTeams = nullptr;
+ /// Per-dimension 'thread_limit' values of the 'target' construct. Must have
+ /// the same number of entries as TeamsThreadLimit. Each dimension launches
+ /// with the minimum of its non-null thread-limiting values (converted to
+ /// i32 as unsigned).
+ SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+ /// Per-dimension 'thread_limit' values of the 'teams' construct. More than
+ /// one entry denotes the multi-dimensional (ompx_bare) case, in which
+ /// MaxThreads is ignored.
+ SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+ /// 'parallel' construct 'num_threads' clause value, if present and it is a
+ /// target SPMD kernel.
+ Value *MaxThreads = nullptr;
+
+ /// Total number of iterations of the target SPMD kernel or null if it is a
+ /// generic kernel. It is zero-extended or truncated to i64 before being
+ /// passed to the runtime (a null value is passed as 0).
+ Value *LoopTripCount = nullptr;
+ };
+
/// Data structure that contains the needed information to construct the
/// kernel args vector.
struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
///
/// \param Loc where the target data construct was encountered.
/// \param IsOffloadEntry whether it is an offload entry.
+ /// \param IsSPMD whether it is a target SPMD kernel.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
/// \param DefaultAttrs Structure containing the default numbers of threads
/// and teams to launch the kernel with.
+ /// \param RuntimeAttrs Structure containing the runtime numbers of threads
+ /// and teams to launch the kernel with.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
// dependency information as passed in the depend clause
// \param HasNowait Whether the target construct has a `nowait` clause or not.
InsertPointOrErrorTy createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry,
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 4c4d8f867fba511..cc299a9f46ce788 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6731,8 +6731,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
}
+/// Emit an appending-linkage global array named \p Name in the
+/// "llvm.metadata" section, holding every value of \p List cast to a generic
+/// (address space 0) pointer. Nothing is emitted when \p List is empty.
+/// NOTE(review): this duplicates the existing OpenMPIRBuilder::emitUsed member
+/// function; consider reusing that one instead of adding a second copy.
+static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
+ Module &M) {
+ if (List.empty())
+ return;
+
+ Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
+
+ // Convert List to what ConstantArray needs.
+ SmallVector<Constant *, 8> UsedArray;
+ UsedArray.reserve(List.size());
+ for (auto Item : List)
+ UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(&*Item), PtrTy));
+
+ // NOTE(review): a new global is created on every call; if one named Name
+ // already exists, the new one is renamed (e.g. "llvm.compiler.used.1").
+ ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
+ auto *GV =
+ new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
+ llvm::ConstantArray::get(ArrTy, UsedArray), Name);
+
+ GV->setSection("llvm.metadata");
+}
+
+/// Create the constant i8 global "<FunctionName>_exec_mode" holding \p Mode,
+/// with weak linkage and protected visibility, and append it to
+/// \p LLVMCompilerUsed so the caller can list it in llvm.compiler.used.
+static void
+emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ StringRef FunctionName, OMPTgtExecModeFlags Mode,
+ std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
+ auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
+ auto *GVMode = new llvm::GlobalVariable(
+ OMPBuilder.M, Int8Ty, /*isConstant=*/true,
+ llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
+ Twine(FunctionName, "_exec_mode"));
+ GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
+ LLVMCompilerUsed.emplace_back(GVMode);
+}
+
static Expected<Function *> createOutlinedFunction(
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6762,6 +6797,15 @@ static Expected<Function *> createOutlinedFunction(
auto Func =
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
+ if (OMPBuilder.Config.isTargetDevice()) {
+ std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
+ emitExecutionMode(OMPBuilder, Builder, FuncName,
+ IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
+ : OMP_TGT_EXEC_MODE_GENERIC,
+ LLVMCompilerUsed);
+ emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+ }
+
// Save insert point.
IRBuilder<>::InsertPointGuard IPG(Builder);
// If there's a DISubprogram associated with current function, then
@@ -6802,7 +6846,7 @@ static Expected<Function *> createOutlinedFunction(
// Insert target init call in the device compilation pass.
if (OMPBuilder.Config.isTargetDevice())
Builder.restoreIP(
- OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
+ OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
@@ -6998,7 +7042,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
static Error emitTargetOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
- TargetRegionEntryInfo &EntryInfo,
+ bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
Function *&OutlinedFn, Constant *&OutlinedFnID,
SmallVectorImpl<Value *> &Inputs,
@@ -7007,7 +7051,7 @@ static Error emitTargetOutlinedFunction(
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
[&](StringRef EntryFnName) {
- return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+ return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
EntryFnName, Inputs, CBFunc,
ArgAccessorFuncCB);
};
@@ -7307,6 +7351,7 @@ static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7388,11 +7433,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
/*ForEndCall=*/false);
SmallVector<Value *, 3> NumTeamsC;
+ for (auto [DefaultVal, RuntimeVal] :
+ zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+ NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+ if (Clause)
+ Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+ /*isSigned=*/false);
+ return Clause;
+ };
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+ if (Clause)
+ Result = Result
+ ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+ Result, Clause)
+ : Clause;
+ };
+
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+ // the NUM_THREADS clause is overriden by THREAD_LIMIT.
SmallVector<Value *, 3> NumThreadsC;
- for (auto V : DefaultAttrs.MaxTeams)
- NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : DefaultAttrs.MaxThreads)
- NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+ ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+ : nullptr;
+
+ for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
+ RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+ Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+ CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+ CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+ NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+ }
unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
@@ -7401,14 +7478,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
- // TODO: Use correct NumIterations
- Value *NumIterations = Builder.getInt64(0);
+
+ Value *TripCount = RuntimeAttrs.LoopTripCount
+ ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+ Builder.getInt64Ty(),
+ /*isSigned=*/false)
+ : Builder.getInt64(0);
+
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
- KArgs = OpenMPIRBuilder::TargetKernelArgs(
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+ NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
@@ -7430,13 +7512,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
- InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+ InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+ TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
+ assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
+ "trip count not expected if IsSPMD=false");
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -7449,16 +7535,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB
if (Error Err = emitTargetOutlinedFunction(
- *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
+ *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+ OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
return Err;
// If we are not on the target device, then we need to generate code
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+ emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+ HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b0688d6215e42d3..a8c786b5886afe0 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6123,7 +6123,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
OMPBuilder.setConfig(Config);
F->setName("func");
IRBuilder<> Builder(BB);
- auto Int32Ty = Builder.getInt32Ty();
+ auto *Int32Ty = Builder.getInt32Ty();
AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6183,11 +6183,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
- /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
- Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+ RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+ RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
OMPBuilder.finalize();
@@ -6207,6 +6211,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
StringRef FunctionName = KernelLaunchFunc->getName();
EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+ // Check num_teams and num_threads in call arguments
+ EXPECT_TRUE(Call->arg_size() >= 4);
+ Value *NumTeamsArg = Call->getArgOperand(2);
+ EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+ EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+ Value *NumThreadsArg = Call->getArgOperand(3);
+ EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+ EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+ // Check num_teams and num_threads kernel arguments (use number 5 starting
+ // from the end and counting the call to __tgt_target_kernel as the first use)
+ Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+ EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+ Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+ Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+ Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+ auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+ EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+ EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+ Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+ Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+ Value *NumThreadsStoreArg =
+ cast<StoreInst>(NumThreadsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+ auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+ EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+ EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
// Check the fallback call
BasicBlock *FallbackBlock = Branch->getSuccessor(0);
Iter = FallbackBlock->rbegin();
@@ -6297,9 +6338,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
- CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -6378,6 +6421,197 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
EXPECT_EQ(ExitBlock->getName(), "worker.exit");
EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
+
+ // Check global exec_mode.
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+ EXPECT_NE(Used, nullptr);
+ Constant *UsedInit = Used->getInitializer();
+ EXPECT_NE(UsedInit, nullptr);
+ EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+ auto *UsedInitData = cast<ConstantArray>(UsedInit);
+ EXPECT_EQ(1U, UsedInitData->getNumOperands());
+ Constant *ExecMode = UsedInitData->getOperand(0);
+ EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+ Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+ EXPECT_NE(ExecModeValue, nullptr);
+ EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+ EXPECT_EQ(OMP_TGT_EXEC_MODE_GENERIC,
+ cast<ConstantInt>(ExecModeValue)->getZExtValue());
+}
+
+// Host compilation of an SPMD target region: the LoopTripCount runtime
+// attribute must reach the kernel arguments passed to __tgt_target_kernel.
+TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ OpenMPIRBuilderConfig Config(/*IsTargetDevice=*/false, /*IsGPU=*/false,
+ /*OpenMPOffloadMandatory=*/false,
+ /*HasRequiresReverseOffload=*/false,
+ /*HasRequiresUnifiedAddress=*/false,
+ /*HasRequiresUnifiedSharedMemory=*/false,
+ /*HasRequiresDynamicAllocators=*/false);
+ OMPBuilder.setConfig(Config);
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+
+ auto BodyGenCB = [&](InsertPointTy,
+ InsertPointTy CodeGenIP) -> InsertPointTy {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
+
+ auto SimpleArgAccessorCB =
+ [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
+ llvm::OpenMPIRBuilder::InsertPointTy,
+ llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
+
+ llvm::SmallVector<llvm::Value *> Inputs;
+ llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
+ auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
+ -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+
+ TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
+ OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ RuntimeAttrs.LoopTripCount = Builder.getInt64(1000);
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ assert(AfterIP && "unexpected error");
+ Builder.restoreIP(*AfterIP);
+ OMPBuilder.finalize();
+ Builder.CreateRetVoid();
+
+ // Check the kernel launch sequence. Fatal assertions are used wherever a
+ // later cast or dereference depends on the checked property.
+ auto Iter = F->getEntryBlock().rbegin();
+ ASSERT_TRUE(isa<BranchInst>(&*(Iter)));
+ BranchInst *Branch = cast<BranchInst>(&*(Iter));
+ EXPECT_TRUE(isa<CmpInst>(&*(++Iter)));
+ ASSERT_TRUE(isa<CallInst>(&*(++Iter)));
+ CallInst *Call = cast<CallInst>(&*(Iter));
+
+ // Check that the kernel launch function is called
+ Function *KernelLaunchFunc = Call->getCalledFunction();
+ ASSERT_NE(KernelLaunchFunc, nullptr);
+ StringRef FunctionName = KernelLaunchFunc->getName();
+ EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+
+ // Check the trip count kernel argument. The user list is walked starting
+ // from the most recent use, so index 0 is the call to __tgt_target_kernel
+ // and the trip count is stored through the user at index 5.
+ Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+ ASSERT_GE(KernelArgs->getNumUses(), 6U);
+ Value *TripCountGetElemPtr = *std::next(KernelArgs->user_begin(), 5);
+ ASSERT_TRUE(isa<GetElementPtrInst>(TripCountGetElemPtr));
+ Value *TripCountStore = TripCountGetElemPtr->getUniqueUndroppableUser();
+ ASSERT_NE(TripCountStore, nullptr);
+ ASSERT_TRUE(isa<StoreInst>(TripCountStore));
+ Value *TripCountStoreArg = cast<StoreInst>(TripCountStore)->getValueOperand();
+ ASSERT_TRUE(isa<ConstantInt>(TripCountStoreArg));
+ EXPECT_EQ(1000U, cast<ConstantInt>(TripCountStoreArg)->getZExtValue());
+
+ // Check the fallback call
+ BasicBlock *FallbackBlock = Branch->getSuccessor(0);
+ Iter = FallbackBlock->rbegin();
+ CallInst *FCall = dyn_cast<CallInst>(&*(++Iter));
+ ASSERT_NE(FCall, nullptr);
+ // 'F' has a dummy DISubprogram which causes OutlinedFunc to also
+ // have a DISubprogram. In this case, the call to OutlinedFunc needs
+ // to have a debug loc, otherwise verifier will complain.
+ FCall->setDebugLoc(DL);
+
+ // Check that the outlined function exists with the expected prefix
+ Function *OutlinedFunc = FCall->getCalledFunction();
+ ASSERT_NE(OutlinedFunc, nullptr);
+ StringRef FunctionName2 = OutlinedFunc->getName();
+ EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
+
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
+// Device compilation of an SPMD target region: the outlined kernel must come
+// with an "_exec_mode" global set to OMP_TGT_EXEC_MODE_SPMD, listed in
+// llvm.compiler.used.
+TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.setConfig(
+ OpenMPIRBuilderConfig(/*IsTargetDevice=*/true, /*IsGPU=*/false,
+ /*OpenMPOffloadMandatory=*/false,
+ /*HasRequiresReverseOffload=*/false,
+ /*HasRequiresUnifiedAddress=*/false,
+ /*HasRequiresUnifiedSharedMemory=*/false,
+ /*HasRequiresDynamicAllocators=*/false));
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+ Function *OutlinedFn = nullptr;
+ llvm::SmallVector<llvm::Value *> CapturedArgs;
+
+ auto SimpleArgAccessorCB =
+ [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
+ llvm::OpenMPIRBuilder::InsertPointTy,
+ llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
+
+ llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
+ auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
+ -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+
+ auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP)
+ -> OpenMPIRBuilder::InsertPointTy {
+ Builder.restoreIP(CodeGenIP);
+ OutlinedFn = CodeGenIP.getBlock()->getParent();
+ return Builder.saveIP();
+ };
+
+ IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
+ F->getEntryBlock().getFirstInsertionPt());
+ TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
+ /*Line=*/3, /*Count=*/0);
+
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB);
+ assert(AfterIP && "unexpected error");
+ Builder.restoreIP(*AfterIP);
+
+ Builder.CreateRetVoid();
+ OMPBuilder.finalize();
+
+ // Check outlined function. Fatal assertions guard every later dereference.
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+ ASSERT_NE(OutlinedFn, nullptr);
+ EXPECT_NE(F, OutlinedFn);
+
+ EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
+ EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
+ // Account for the "implicit" first argument.
+ EXPECT_EQ(OutlinedFn->arg_size(), 1U);
+
+ // Check global exec_mode.
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+ ASSERT_NE(Used, nullptr);
+ Constant *UsedInit = Used->getInitializer();
+ ASSERT_NE(UsedInit, nullptr);
+ ASSERT_TRUE(isa<ConstantArray>(UsedInit));
+ auto *UsedInitData = cast<ConstantArray>(UsedInit);
+ ASSERT_EQ(1U, UsedInitData->getNumOperands());
+ Constant *ExecMode = UsedInitData->getOperand(0);
+ ASSERT_TRUE(isa<GlobalVariable>(ExecMode));
+ Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+ ASSERT_NE(ExecModeValue, nullptr);
+ ASSERT_TRUE(isa<ConstantInt>(ExecModeValue));
+ EXPECT_EQ(OMP_TGT_EXEC_MODE_SPMD,
+ cast<ConstantInt>(ExecModeValue)->getZExtValue());
}
TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
@@ -6448,9 +6682,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
- CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index cca2613ce102afa..f30ba2c29261625 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3951,9 +3951,11 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
allocaIP, codeGenIP);
};
- // TODO: Populate default attributes based on the construct and clauses.
+ // TODO: Populate default and runtime attributes based on the construct and
+ // clauses.
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
llvm::SmallVector<llvm::Value *, 4> kernelInput;
for (size_t i = 0; i < mapVars.size(); ++i) {
@@ -3973,9 +3975,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
- ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
- defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
- targetOp.getNowait());
+ ompLoc, isOffloadEntry, /*IsSPMD=*/false, allocaIP, builder.saveIP(),
+ entryInfo, defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB,
+ bodyCB, argAccessorCB, dds, targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
>From e9ea3a501bb97a4ddaa85b795874e343cae40f2b Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 27 Nov 2024 12:08:46 +0000
Subject: [PATCH 2/3] Address review comments
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 10 +-
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 104 +++++++-----------
.../Frontend/OpenMPIRBuilderTest.cpp | 46 ++++----
3 files changed, 68 insertions(+), 92 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 444bc280df9f89b..3a640fbd7336951 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1387,9 +1387,6 @@ class OpenMPIRBuilder {
/// Supporting functions for Reductions CodeGen.
private:
- /// Emit the llvm.used metadata.
- void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List);
-
/// Get the id of the current thread on the GPU.
Value *getGPUThreadID();
@@ -2011,6 +2008,13 @@ class OpenMPIRBuilder {
/// Value.
GlobalValue *createGlobalFlag(unsigned Value, StringRef Name);
+ /// Emit an appending-linkage global array named \p Name (for example
+ /// "llvm.used" or "llvm.compiler.used") in the "llvm.metadata" section,
+ /// holding a pointer to each value in \p List. Nothing is emitted when
+ /// \p List is empty.
+ void emitUsed(StringRef Name, ArrayRef<llvm::WeakTrackingVH> List);
+
+ /// Create the constant i8 global "<KernelName>_exec_mode" initialized to
+ /// \p Mode, with weak linkage and protected visibility, and return it. It is
+ /// not added to any "used" list; the caller is responsible for that if needed.
+ GlobalVariable *emitKernelExecutionMode(StringRef KernelName,
+ omp::OMPTgtExecModeFlags Mode);
+
/// Generate control flow and cleanup for cancellation.
///
/// \param CancelFlag Flag indicating if the cancellation is performed.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index cc299a9f46ce788..dcf2515311eabcd 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
return GV;
}
+/// Emit an appending-linkage global array named \p Name (e.g.
+/// "llvm.compiler.used") in the "llvm.metadata" section, holding a pointer to
+/// every value in \p List. If an appending array with that name already exists,
+/// its entries are kept and \p List is appended to them. Repeated calls (one
+/// per kernel, for instance) therefore yield a single array instead of
+/// renamed copies such as "llvm.compiler.used.1". Nothing is emitted when
+/// \p List is empty.
+void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
+ if (List.empty())
+ return;
+
+ Type *PtrTy = Builder.getPtrTy();
+ SmallVector<Constant *, 8> UsedArray;
+
+ // Reuse the entries of an existing array with the same name. Creating a
+ // second global would make LLVM rename it, and the renamed copy would no
+ // longer be recognized as a special "used" array.
+ GlobalVariable *Existing = M.getNamedGlobal(Name);
+ if (Existing && Existing->hasAppendingLinkage() && Existing->use_empty()) {
+ if (Existing->hasInitializer())
+ if (auto *Init = dyn_cast<ConstantArray>(Existing->getInitializer()))
+ for (Use &Op : Init->operands())
+ UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(Op.get()), PtrTy));
+ Existing->eraseFromParent();
+ }
+
+ // Convert List to what ConstantArray needs. Every entry is cast to a generic
+ // (address space 0) pointer so that all array elements share one type.
+ UsedArray.reserve(UsedArray.size() + List.size());
+ for (const WeakTrackingVH &Item : List)
+ UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(&*Item), PtrTy));
+
+ ArrayType *ATy = ArrayType::get(PtrTy, UsedArray.size());
+ auto *GV = new GlobalVariable(M, ATy, /*isConstant=*/false,
+ GlobalValue::AppendingLinkage,
+ ConstantArray::get(ATy, UsedArray), Name);
+
+ GV->setSection("llvm.metadata");
+}
+
+/// Build the read-only i8 global "<KernelName>_exec_mode" whose value is
+/// \p Mode. It has weak linkage and protected visibility, and it is returned
+/// to the caller without being added to any "used" list.
+GlobalVariable *
+OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
+ OMPTgtExecModeFlags Mode) {
+ IntegerType *ModeTy = Builder.getInt8Ty();
+ ConstantInt *ModeValue = ConstantInt::get(ModeTy, Mode);
+ auto *ExecModeGV =
+ new GlobalVariable(M, ModeTy, /*isConstant=*/true,
+ GlobalValue::WeakAnyLinkage, ModeValue,
+ Twine(KernelName, "_exec_mode"));
+ ExecModeGV->setVisibility(GlobalVariable::ProtectedVisibility);
+ return ExecModeGV;
+}
+
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
uint32_t SrcLocStrSize,
IdentFlag LocFlags,
@@ -2246,28 +2278,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
}
-void OpenMPIRBuilder::emitUsed(StringRef Name,
- std::vector<WeakTrackingVH> &List) {
- if (List.empty())
- return;
-
- // Convert List to what ConstantArray needs.
- SmallVector<Constant *, 8> UsedArray;
- UsedArray.resize(List.size());
- for (unsigned I = 0, E = List.size(); I != E; ++I)
- UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
- cast<Constant>(&*List[I]), Builder.getPtrTy());
-
- if (UsedArray.empty())
- return;
- ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
-
- auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
- ConstantArray::get(ATy, UsedArray), Name);
-
- GV->setSection("llvm.metadata");
-}
-
Value *OpenMPIRBuilder::getGPUThreadID() {
return Builder.CreateCall(
getOrCreateRuntimeFunction(M,
@@ -6731,41 +6741,6 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
}
-static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
- Module &M) {
- if (List.empty())
- return;
-
- Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
-
- // Convert List to what ConstantArray needs.
- SmallVector<Constant *, 8> UsedArray;
- UsedArray.reserve(List.size());
- for (auto Item : List)
- UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
- cast<Constant>(&*Item), PtrTy));
-
- ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
- auto *GV =
- new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
- llvm::ConstantArray::get(ArrTy, UsedArray), Name);
-
- GV->setSection("llvm.metadata");
-}
-
-static void
-emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
- StringRef FunctionName, OMPTgtExecModeFlags Mode,
- std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
- auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
- auto *GVMode = new llvm::GlobalVariable(
- OMPBuilder.M, Int8Ty, /*isConstant=*/true,
- llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
- Twine(FunctionName, "_exec_mode"));
- GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
- LLVMCompilerUsed.emplace_back(GVMode);
-}
-
static Expected<Function *> createOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
@@ -6798,12 +6773,9 @@ static Expected<Function *> createOutlinedFunction(
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
if (OMPBuilder.Config.isTargetDevice()) {
- std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
- emitExecutionMode(OMPBuilder, Builder, FuncName,
- IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
- : OMP_TGT_EXEC_MODE_GENERIC,
- LLVMCompilerUsed);
- emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+ Value *ExecMode = OMPBuilder.emitKernelExecutionMode(
+ FuncName, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
+ OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
}
// Save insert point.
@@ -7460,8 +7432,8 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
: nullptr;
- for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
- RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+ for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
+ RuntimeAttrs.TargetThreadLimit)) {
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
Value *NumThreads = InitMaxThreadsClause(TargetVal);
@@ -7521,8 +7493,6 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
- assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
- "trip count not expected if IsSPMD=false");
if (!updateToLocation(Loc))
return InsertPointTy();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index a8c786b5886afe0..e4845256633b9c8 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6459,18 +6459,19 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
return Builder.saveIP();
};
- auto SimpleArgAccessorCB =
- [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
- llvm::OpenMPIRBuilder::InsertPointTy,
- llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
- Builder.restoreIP(CodeGenIP);
- return Builder.saveIP();
- };
+ auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
+ OpenMPIRBuilder::InsertPointTy,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
- llvm::SmallVector<llvm::Value *> Inputs;
- llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
- auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
- -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+ SmallVector<Value *> Inputs;
+ OpenMPIRBuilder::MapInfosTy CombinedInfos;
+ auto GenMapInfoCB =
+ [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & {
+ return CombinedInfos;
+ };
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
@@ -6547,19 +6548,20 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
Function *OutlinedFn = nullptr;
- llvm::SmallVector<llvm::Value *> CapturedArgs;
+ SmallVector<Value *> CapturedArgs;
- auto SimpleArgAccessorCB =
- [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
- llvm::OpenMPIRBuilder::InsertPointTy,
- llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
- Builder.restoreIP(CodeGenIP);
- return Builder.saveIP();
- };
+ auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
+ OpenMPIRBuilder::InsertPointTy,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
- llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
- auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
- -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+ OpenMPIRBuilder::MapInfosTy CombinedInfos;
+ auto GenMapInfoCB =
+ [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & {
+ return CombinedInfos;
+ };
auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy,
OpenMPIRBuilder::InsertPointTy CodeGenIP)
>From b1e4eb5699afa9582a136c589adb9c256cc0bc66 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 4 Dec 2024 14:16:17 +0000
Subject: [PATCH 3/3] Fine-grained control of kernel execution mode
---
clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp | 7 ++++-
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 22 +++++++-------
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 30 +++++++++----------
.../Frontend/OpenMPIRBuilderTest.cpp | 15 ++++++----
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 7 +++--
5 files changed, 47 insertions(+), 34 deletions(-)
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 659783a813c83ef..515dbe379eb6e3d 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -20,6 +20,7 @@
#include "clang/AST/StmtVisitor.h"
#include "clang/Basic/Cuda.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
using namespace clang;
@@ -748,7 +749,11 @@ void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
CGBuilderTy &Bld = CGF.Builder;
- Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, Attrs));
+ Bld.restoreIP(OMPBuilder.createTargetInit(
+ Bld,
+ IsSPMD ? llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD
+ : llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
+ Attrs));
if (!IsSPMD)
emitGenericVarsProlog(CGF, EST.Loc);
}
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3a640fbd7336951..580b2b3e2341580 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2243,21 +2243,21 @@ class OpenMPIRBuilder {
/// Container to pass LLVM IR runtime values or constants related to the
/// number of teams and threads with which the kernel must be launched, as
- /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
- /// must be defined in the host prior to the call to the kernel launch OpenMP
- /// RTL function.
+ /// well as the trip count of the loop, if it is an SPMD or Generic-SPMD
+ /// kernel. These must be defined in the host prior to the call to the kernel
+ /// launch OpenMP RTL function.
struct TargetKernelRuntimeAttrs {
SmallVector<Value *, 3> MaxTeams = {nullptr};
Value *MinTeams = nullptr;
SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
- /// 'parallel' construct 'num_threads' clause value, if present and it is a
- /// target SPMD kernel.
+ /// 'parallel' construct 'num_threads' clause value, if present and it is an
+ /// SPMD kernel.
Value *MaxThreads = nullptr;
- /// Total number of iterations of the target SPMD kernel or null if it is a
- /// generic kernel.
+ /// Total number of iterations of the SPMD or Generic-SPMD kernel or null if
+ /// it is a generic kernel.
Value *LoopTripCount = nullptr;
};
@@ -2763,11 +2763,11 @@ class OpenMPIRBuilder {
/// Create a runtime call for kmpc_target_init
///
/// \param Loc The insert and source location description.
+ /// \param ExecFlags Kernel execution mode flags.
- /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
/// \param Attrs Structure containing the default numbers of threads and teams
/// to launch the kernel with.
InsertPointTy createTargetInit(
- const LocationDescription &Loc, bool IsSPMD,
+ const LocationDescription &Loc, omp::OMPTgtExecModeFlags ExecFlags,
const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
/// Create a runtime call for kmpc_target_deinit
@@ -2929,7 +2930,7 @@ class OpenMPIRBuilder {
///
/// \param Loc where the target data construct was encountered.
/// \param IsOffloadEntry whether it is an offload entry.
- /// \param IsSPMD whether it is a target SPMD kernel.
+ /// \param ExecFlags kernel execution mode flags.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
@@ -2946,7 +2947,8 @@ class OpenMPIRBuilder {
// dependency information as passed in the depend clause
// \param HasNowait Whether the target construct has a `nowait` clause or not.
InsertPointOrErrorTy createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+ const LocationDescription &Loc, bool IsOffloadEntry,
+ omp::OMPTgtExecModeFlags ExecFlags,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index dcf2515311eabcd..28f85460624dd63 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6124,7 +6124,7 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
}
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
- const LocationDescription &Loc, bool IsSPMD,
+ const LocationDescription &Loc, omp::OMPTgtExecModeFlags ExecFlags,
const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
"expected num_threads and num_teams to be specified");
@@ -6135,9 +6135,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
uint32_t SrcLocStrSize;
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
- Constant *IsSPMDVal = ConstantInt::getSigned(
- Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
- Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Int8, !IsSPMD);
+ Constant *IsSPMDVal = ConstantInt::getSigned(Int8, ExecFlags);
+ Constant *UseGenericStateMachineVal =
+ ConstantInt::getSigned(Int8, ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
@@ -6742,7 +6742,8 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
}
static Expected<Function *> createOutlinedFunction(
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ omp::OMPTgtExecModeFlags ExecFlags,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6773,8 +6774,7 @@ static Expected<Function *> createOutlinedFunction(
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
if (OMPBuilder.Config.isTargetDevice()) {
- Value *ExecMode = OMPBuilder.emitKernelExecutionMode(
- FuncName, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
+ Value *ExecMode = OMPBuilder.emitKernelExecutionMode(FuncName, ExecFlags);
OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
}
@@ -6818,7 +6818,7 @@ static Expected<Function *> createOutlinedFunction(
// Insert target init call in the device compilation pass.
if (OMPBuilder.Config.isTargetDevice())
Builder.restoreIP(
- OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
+ OMPBuilder.createTargetInit(Builder, ExecFlags, DefaultAttrs));
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
@@ -7014,7 +7014,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
static Error emitTargetOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
- bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
+ omp::OMPTgtExecModeFlags ExecFlags, TargetRegionEntryInfo &EntryInfo,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
Function *&OutlinedFn, Constant *&OutlinedFnID,
SmallVectorImpl<Value *> &Inputs,
@@ -7023,8 +7023,8 @@ static Error emitTargetOutlinedFunction(
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
[&](StringRef EntryFnName) {
- return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
- EntryFnName, Inputs, CBFunc,
+ return createOutlinedFunction(OMPBuilder, Builder, ExecFlags,
+ DefaultAttrs, EntryFnName, Inputs, CBFunc,
ArgAccessorFuncCB);
};
@@ -7484,9 +7484,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
- InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
- TargetRegionEntryInfo &EntryInfo,
+ const LocationDescription &Loc, bool IsOffloadEntry,
+ omp::OMPTgtExecModeFlags ExecFlags, InsertPointTy AllocaIP,
+ InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
@@ -7505,7 +7505,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB
if (Error Err = emitTargetOutlinedFunction(
- *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+ *this, Builder, IsOffloadEntry, ExecFlags, EntryInfo, DefaultAttrs,
OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
return Err;
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index e4845256633b9c8..90a0a92888310cc 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6189,7 +6189,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
RuntimeAttrs.MaxThreads = Builder.getInt32(40);
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+ OmpLoc, /*IsOffloadEntry=*/true,
+ omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, Builder.saveIP(),
Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
@@ -6340,7 +6341,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+ Loc, /*IsOffloadEntry=*/true,
+ omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, EntryIP, EntryIP,
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
@@ -6480,7 +6482,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
RuntimeAttrs.LoopTripCount = Builder.getInt64(1000);
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, Builder.saveIP(),
+ OmpLoc, /*IsOffloadEntry=*/true,
+ omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, Builder.saveIP(),
Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
@@ -6580,7 +6583,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, EntryIP, EntryIP,
+ Loc, /*IsOffloadEntry=*/true,
+ omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, EntryIP, EntryIP,
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
@@ -6686,7 +6690,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+ Loc, /*IsOffloadEntry=*/true,
+ omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, EntryIP, EntryIP,
EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f30ba2c29261625..acdbbcd5eafa21e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3975,9 +3975,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
- ompLoc, isOffloadEntry, /*IsSPMD=*/false, allocaIP, builder.saveIP(),
- entryInfo, defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB,
- bodyCB, argAccessorCB, dds, targetOp.getNowait());
+ ompLoc, isOffloadEntry, llvm::omp::OMP_TGT_EXEC_MODE_GENERIC,
+ allocaIP, builder.saveIP(), entryInfo, defaultAttrs, runtimeAttrs,
+ kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
+ targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
More information about the llvm-branch-commits
mailing list