[clang] [llvm] [mlir] [OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode (PR #116051)
Sergio Afonso via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 14 03:25:44 PST 2025
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116051
>From 0c19f7119c1da0646466b0eb1c3c77faedabaebf Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 27 Nov 2024 11:33:01 +0000
Subject: [PATCH] [OMPIRBuilder] Support runtime number of teams and threads,
and SPMD mode
This patch introduces a `TargetKernelRuntimeAttrs` structure to hold
host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values
passed to the runtime kernel offloading call.
Additionally, kernel type information is used to influence target device code
generation, and the `IsSPMD` flag is replaced by `ExecFlags`, which provides
finer granularity.
---
clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp | 5 +-
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 38 ++-
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 129 +++++---
.../Frontend/OpenMPIRBuilderTest.cpp | 281 ++++++++++++++++--
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 12 +-
5 files changed, 398 insertions(+), 67 deletions(-)
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 81993dafae2b03..87c3635ed3f70e 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -20,6 +20,7 @@
#include "clang/AST/StmtVisitor.h"
#include "clang/Basic/Cuda.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
using namespace clang;
@@ -745,7 +746,9 @@ void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
CodeGenFunction &CGF,
EntryFunctionState &EST, bool IsSPMD) {
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
- Attrs.IsSPMD = IsSPMD;
+ Attrs.ExecFlags =
+ IsSPMD ? llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD
+ : llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC;
computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
CGBuilderTy &Bld = CGF.Builder;
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 8ca3bc08b5ad49..7eceec3d8cf8f5 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1389,9 +1389,6 @@ class OpenMPIRBuilder {
/// Supporting functions for Reductions CodeGen.
private:
- /// Emit the llvm.used metadata.
- void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List);
-
/// Get the id of the current thread on the GPU.
Value *getGPUThreadID();
@@ -2013,6 +2010,13 @@ class OpenMPIRBuilder {
/// Value.
GlobalValue *createGlobalFlag(unsigned Value, StringRef Name);
+ /// Emit the llvm.used metadata.
+ void emitUsed(StringRef Name, ArrayRef<llvm::WeakTrackingVH> List);
+
+ /// Emit the kernel execution mode.
+ GlobalVariable *emitKernelExecutionMode(StringRef KernelName,
+ omp::OMPTgtExecModeFlags Mode);
+
/// Generate control flow and cleanup for cancellation.
///
/// \param CancelFlag Flag indicating if the cancellation is performed.
@@ -2233,13 +2237,34 @@ class OpenMPIRBuilder {
/// time. The number of max values will be 1 except for the case where
/// ompx_bare is set.
struct TargetKernelDefaultAttrs {
- bool IsSPMD = false;
+ omp::OMPTgtExecModeFlags ExecFlags =
+ omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC;
SmallVector<int32_t, 3> MaxTeams = {-1};
int32_t MinTeams = 1;
SmallVector<int32_t, 3> MaxThreads = {-1};
int32_t MinThreads = 1;
};
+ /// Container to pass LLVM IR runtime values or constants related to the
+ /// number of teams and threads with which the kernel must be launched, as
+ /// well as the trip count of the loop, if it is an SPMD or Generic-SPMD
+ /// kernel. These must be defined in the host prior to the call to the kernel
+ /// launch OpenMP RTL function.
+ struct TargetKernelRuntimeAttrs {
+ SmallVector<Value *, 3> MaxTeams = {nullptr};
+ Value *MinTeams = nullptr;
+ SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+ SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+ /// 'parallel' construct 'num_threads' clause value, if present and it is an
+ /// SPMD kernel.
+ Value *MaxThreads = nullptr;
+
+ /// Total number of iterations of the SPMD or Generic-SPMD kernel or null if
+ /// it is a generic kernel.
+ Value *LoopTripCount = nullptr;
+ };
+
/// Data structure that contains the needed information to construct the
/// kernel args vector.
struct TargetKernelArgs {
@@ -2971,7 +2996,9 @@ class OpenMPIRBuilder {
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
- /// \param DefaultAttrs Structure containing the default numbers of threads
+ /// \param DefaultAttrs Structure containing the default attributes, including
+ /// numbers of threads and teams to launch the kernel with.
+ /// \param RuntimeAttrs Structure containing the runtime numbers of threads
/// and teams to launch the kernel with.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
@@ -2987,6 +3014,7 @@ class OpenMPIRBuilder {
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index df9b35ddd80ca4..3242b38502300d 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
return GV;
}
+void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
+ if (List.empty())
+ return;
+
+ // Convert List to what ConstantArray needs.
+ SmallVector<Constant *, 8> UsedArray;
+ UsedArray.resize(List.size());
+ for (unsigned I = 0, E = List.size(); I != E; ++I)
+ UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(&*List[I]), Builder.getPtrTy());
+
+ if (UsedArray.empty())
+ return;
+ ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
+
+ auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
+ ConstantArray::get(ATy, UsedArray), Name);
+
+ GV->setSection("llvm.metadata");
+}
+
+GlobalVariable *
+OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
+ OMPTgtExecModeFlags Mode) {
+ auto *Int8Ty = Builder.getInt8Ty();
+ auto *GVMode = new GlobalVariable(
+ M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
+ ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
+ GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
+ return GVMode;
+}
+
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
uint32_t SrcLocStrSize,
IdentFlag LocFlags,
@@ -2260,28 +2292,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
}
-void OpenMPIRBuilder::emitUsed(StringRef Name,
- std::vector<WeakTrackingVH> &List) {
- if (List.empty())
- return;
-
- // Convert List to what ConstantArray needs.
- SmallVector<Constant *, 8> UsedArray;
- UsedArray.resize(List.size());
- for (unsigned I = 0, E = List.size(); I != E; ++I)
- UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
- cast<Constant>(&*List[I]), Builder.getPtrTy());
-
- if (UsedArray.empty())
- return;
- ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
-
- auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
- ConstantArray::get(ATy, UsedArray), Name);
-
- GV->setSection("llvm.metadata");
-}
-
Value *OpenMPIRBuilder::getGPUThreadID() {
return Builder.CreateCall(
getOrCreateRuntimeFunction(M,
@@ -6131,10 +6141,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
uint32_t SrcLocStrSize;
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
- Constant *IsSPMDVal = ConstantInt::getSigned(
- Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
- Constant *UseGenericStateMachineVal =
- ConstantInt::getSigned(Int8, !Attrs.IsSPMD);
+ Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags);
+ Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
+ Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
@@ -6765,6 +6774,12 @@ static Expected<Function *> createOutlinedFunction(
auto Func =
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
+ if (OMPBuilder.Config.isTargetDevice()) {
+ Value *ExecMode =
+ OMPBuilder.emitKernelExecutionMode(FuncName, DefaultAttrs.ExecFlags);
+ OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
+ }
+
// Save insert point.
IRBuilder<>::InsertPointGuard IPG(Builder);
// If there's a DISubprogram associated with current function, then
@@ -7312,6 +7327,7 @@ static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7393,11 +7409,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
/*ForEndCall=*/false);
SmallVector<Value *, 3> NumTeamsC;
+ for (auto [DefaultVal, RuntimeVal] :
+ zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+ NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+ if (Clause)
+ Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+ /*isSigned=*/false);
+ return Clause;
+ };
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+ if (Clause)
+ Result = Result
+ ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+ Result, Clause)
+ : Clause;
+ };
+
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+ // the NUM_THREADS clause is overridden by THREAD_LIMIT.
SmallVector<Value *, 3> NumThreadsC;
- for (auto V : DefaultAttrs.MaxTeams)
- NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : DefaultAttrs.MaxThreads)
- NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+ ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+ : nullptr;
+
+ for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
+ RuntimeAttrs.TargetThreadLimit)) {
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+ Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+ CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+ CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+ NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+ }
unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
@@ -7406,14 +7454,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
- // TODO: Use correct NumIterations
- Value *NumIterations = Builder.getInt64(0);
+
+ Value *TripCount = RuntimeAttrs.LoopTripCount
+ ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+ Builder.getInt64Ty(),
+ /*isSigned=*/false)
+ : Builder.getInt64(0);
+
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
- KArgs = OpenMPIRBuilder::TargetKernelArgs(
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+ NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
@@ -7438,6 +7491,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7462,8 +7516,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+ emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+ HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 04ecd7ef327d5a..11f13beb9865c0 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6170,7 +6170,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
OMPBuilder.setConfig(Config);
F->setName("func");
IRBuilder<> Builder(BB);
- auto Int32Ty = Builder.getInt32Ty();
+ auto *Int32Ty = Builder.getInt32Ty();
AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6229,16 +6229,22 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
- /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
- /*MinThreads=*/0};
+ /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
+ /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+ RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+ RuntimeAttrs.MaxThreads = Builder.getInt32(40);
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
- Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+ Builder.saveIP(), EntryInfo, DefaultAttrs,
+ RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
+ SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
+
OMPBuilder.finalize();
Builder.CreateRetVoid();
@@ -6256,6 +6262,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
StringRef FunctionName = KernelLaunchFunc->getName();
EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+ // Check num_teams and num_threads in call arguments
+ EXPECT_TRUE(Call->arg_size() >= 4);
+ Value *NumTeamsArg = Call->getArgOperand(2);
+ EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+ EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+ Value *NumThreadsArg = Call->getArgOperand(3);
+ EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+ EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+ // Check num_teams and num_threads kernel arguments (uses at offsets 3 and 2
+ // from user_begin(), where offset 0 is the call to __tgt_target_kernel)
+ Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+ EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+ Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+ Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+ Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+ auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+ EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+ EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+ Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+ Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+ Value *NumThreadsStoreArg =
+ cast<StoreInst>(NumThreadsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+ auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+ EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+ EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
// Check the fallback call
BasicBlock *FallbackBlock = Branch->getSuccessor(0);
Iter = FallbackBlock->rbegin();
@@ -6343,15 +6386,16 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
F->getEntryBlock().getFirstInsertionPt());
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
- /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
- /*MinThreads=*/0};
+ /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
- ASSERT_EXPECTED_INIT(
- OpenMPIRBuilder::InsertPointTy, AfterIP,
- OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, DefaultAttrs, CapturedArgs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+ ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
+ OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
Builder.CreateRetVoid();
@@ -6435,6 +6479,204 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
EXPECT_EQ(ExitBlock->getName(), "worker.exit");
EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
+
+ // Check global exec_mode.
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+ EXPECT_NE(Used, nullptr);
+ Constant *UsedInit = Used->getInitializer();
+ EXPECT_NE(UsedInit, nullptr);
+ EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+ auto *UsedInitData = cast<ConstantArray>(UsedInit);
+ EXPECT_EQ(1U, UsedInitData->getNumOperands());
+ Constant *ExecMode = UsedInitData->getOperand(0);
+ EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+ Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+ EXPECT_NE(ExecModeValue, nullptr);
+ EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+ EXPECT_EQ(OMP_TGT_EXEC_MODE_GENERIC,
+ cast<ConstantInt>(ExecModeValue)->getZExtValue());
+}
+
+TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ OpenMPIRBuilderConfig Config(/*IsTargetDevice=*/false, /*IsGPU=*/false,
+ /*OpenMPOffloadMandatory=*/false,
+ /*HasRequiresReverseOffload=*/false,
+ /*HasRequiresUnifiedAddress=*/false,
+ /*HasRequiresUnifiedSharedMemory=*/false,
+ /*HasRequiresDynamicAllocators=*/false);
+ OMPBuilder.setConfig(Config);
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+
+ auto BodyGenCB = [&](InsertPointTy,
+ InsertPointTy CodeGenIP) -> InsertPointTy {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
+
+ auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
+ OpenMPIRBuilder::InsertPointTy,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
+
+ SmallVector<Value *> Inputs;
+ OpenMPIRBuilder::MapInfosTy CombinedInfos;
+ auto GenMapInfoCB =
+ [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & {
+ return CombinedInfos;
+ };
+
+ TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
+ OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD,
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ RuntimeAttrs.LoopTripCount = Builder.getInt64(1000);
+
+ ASSERT_EXPECTED_INIT(
+ OpenMPIRBuilder::InsertPointTy, AfterIP,
+ OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs,
+ RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
+ SimpleArgAccessorCB));
+ Builder.restoreIP(AfterIP);
+
+ OMPBuilder.finalize();
+ Builder.CreateRetVoid();
+
+ // Check the kernel launch sequence
+ auto Iter = F->getEntryBlock().rbegin();
+ EXPECT_TRUE(isa<BranchInst>(&*(Iter)));
+ BranchInst *Branch = dyn_cast<BranchInst>(&*(Iter));
+ EXPECT_TRUE(isa<CmpInst>(&*(++Iter)));
+ EXPECT_TRUE(isa<CallInst>(&*(++Iter)));
+ CallInst *Call = dyn_cast<CallInst>(&*(Iter));
+
+ // Check that the kernel launch function is called
+ Function *KernelLaunchFunc = Call->getCalledFunction();
+ EXPECT_NE(KernelLaunchFunc, nullptr);
+ StringRef FunctionName = KernelLaunchFunc->getName();
+ EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+
+ // Check the trip count kernel argument (use at offset 5 from user_begin(),
+ // where offset 0 is the call to __tgt_target_kernel)
+ Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+ EXPECT_TRUE(KernelArgs->getNumUses() >= 6);
+ Value *TripCountGetElemPtr = *std::next(KernelArgs->user_begin(), 5);
+ EXPECT_TRUE(isa<GetElementPtrInst>(TripCountGetElemPtr));
+ Value *TripCountStore = TripCountGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(TripCountStore));
+ Value *TripCountStoreArg = cast<StoreInst>(TripCountStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantInt>(TripCountStoreArg));
+ EXPECT_EQ(1000U, cast<ConstantInt>(TripCountStoreArg)->getZExtValue());
+
+ // Check the fallback call
+ BasicBlock *FallbackBlock = Branch->getSuccessor(0);
+ Iter = FallbackBlock->rbegin();
+ CallInst *FCall = dyn_cast<CallInst>(&*(++Iter));
+ // 'F' has a dummy DISubprogram which causes OutlinedFunc to also
+ // have a DISubprogram. In this case, the call to OutlinedFunc needs
+ // to have a debug loc, otherwise verifier will complain.
+ FCall->setDebugLoc(DL);
+ EXPECT_NE(FCall, nullptr);
+
+ // Check that the outlined function exists with the expected prefix
+ Function *OutlinedFunc = FCall->getCalledFunction();
+ EXPECT_NE(OutlinedFunc, nullptr);
+ StringRef FunctionName2 = OutlinedFunc->getName();
+ EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
+
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
+TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.setConfig(
+ OpenMPIRBuilderConfig(/*IsTargetDevice=*/true, /*IsGPU=*/false,
+ /*OpenMPOffloadMandatory=*/false,
+ /*HasRequiresReverseOffload=*/false,
+ /*HasRequiresUnifiedAddress=*/false,
+ /*HasRequiresUnifiedSharedMemory=*/false,
+ /*HasRequiresDynamicAllocators=*/false));
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+ Function *OutlinedFn = nullptr;
+ SmallVector<Value *> CapturedArgs;
+
+ auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
+ OpenMPIRBuilder::InsertPointTy,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ return Builder.saveIP();
+ };
+
+ OpenMPIRBuilder::MapInfosTy CombinedInfos;
+ auto GenMapInfoCB =
+ [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & {
+ return CombinedInfos;
+ };
+
+ auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP)
+ -> OpenMPIRBuilder::InsertPointTy {
+ Builder.restoreIP(CodeGenIP);
+ OutlinedFn = CodeGenIP.getBlock()->getParent();
+ return Builder.saveIP();
+ };
+
+ IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
+ F->getEntryBlock().getFirstInsertionPt());
+ TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
+ /*Line=*/3, /*Count=*/0);
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD,
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+
+ ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
+ OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+ Builder.restoreIP(AfterIP);
+
+ Builder.CreateRetVoid();
+ OMPBuilder.finalize();
+
+ // Check outlined function
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+ EXPECT_NE(OutlinedFn, nullptr);
+ EXPECT_NE(F, OutlinedFn);
+
+ EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
+ // Account for the "implicit" first argument.
+ EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
+ EXPECT_EQ(OutlinedFn->arg_size(), 1U);
+
+ // Check global exec_mode.
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+ EXPECT_NE(Used, nullptr);
+ Constant *UsedInit = Used->getInitializer();
+ EXPECT_NE(UsedInit, nullptr);
+ EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+ auto *UsedInitData = cast<ConstantArray>(UsedInit);
+ EXPECT_EQ(1U, UsedInitData->getNumOperands());
+ Constant *ExecMode = UsedInitData->getOperand(0);
+ EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+ Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+ EXPECT_NE(ExecModeValue, nullptr);
+ EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+ EXPECT_EQ(OMP_TGT_EXEC_MODE_SPMD,
+ cast<ConstantInt>(ExecModeValue)->getZExtValue());
}
TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
@@ -6502,15 +6744,16 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
F->getEntryBlock().getFirstInsertionPt());
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
- /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
- /*MinThreads=*/0};
+ /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
- ASSERT_EXPECTED_INIT(
- OpenMPIRBuilder::InsertPointTy, AfterIP,
- OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, DefaultAttrs, CapturedArgs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+ ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
+ OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
Builder.CreateRetVoid();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 25b0ffe2ced6de..5c36187540690e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4115,10 +4115,12 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
allocaIP, codeGenIP);
};
- // TODO: Populate default attributes based on the construct and clauses.
+ // TODO: Populate default and runtime attributes based on the construct and
+ // clauses.
+ llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
- /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
- /*MinThreads=*/0};
+ /*ExecFlags=*/llvm::omp::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1},
+ /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
llvm::SmallVector<llvm::Value *, 4> kernelInput;
for (size_t i = 0; i < mapVars.size(); ++i) {
@@ -4143,8 +4145,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
- defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
- targetOp.getNowait());
+ defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, bodyCB,
+ argAccessorCB, dds, targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
More information about the llvm-commits
mailing list