[llvm-branch-commits] [llvm] [mlir] [OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode (PR #116051)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Nov 13 05:41:43 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-llvm
Author: Sergio Afonso (skatrak)
<details>
<summary>Changes</summary>
This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call.
Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to influence target device code generation.
---
Patch is 31.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116051.diff
4 Files Affected:
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+25-1)
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+118-19)
- (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+271-10)
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-4)
``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index da450ef5adbc14..a85f41e586c514 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
int32_t MinThreads = 1;
};
+ /// Container to pass LLVM IR runtime values or constants related to the
+ /// number of teams and threads with which the kernel must be launched, as
+ /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
+ /// must be defined in the host prior to the call to the kernel launch OpenMP
+ /// RTL function.
+ struct TargetKernelRuntimeAttrs {
+ SmallVector<Value *, 3> MaxTeams = {nullptr};
+ Value *MinTeams = nullptr;
+ SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+ SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+ /// 'parallel' construct 'num_threads' clause value, if present and it is a
+ /// target SPMD kernel.
+ Value *MaxThreads = nullptr;
+
+ /// Total number of iterations of the target SPMD kernel or null if it is a
+ /// generic kernel.
+ Value *LoopTripCount = nullptr;
+ };
+
/// Data structure that contains the needed information to construct the
/// kernel args vector.
struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
///
/// \param Loc where the target data construct was encountered.
/// \param IsOffloadEntry whether it is an offload entry.
+ /// \param IsSPMD whether it is a target SPMD kernel.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
/// \param DefaultAttrs Structure containing the default numbers of threads
/// and teams to launch the kernel with.
+ /// \param RuntimeAttrs Structure containing the runtime numbers of threads
+ /// and teams to launch the kernel with.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
// dependency information as passed in the depend clause
// \param HasNowait Whether the target construct has a `nowait` clause or not.
InsertPointOrErrorTy createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry,
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 302d363965c940..f847f60386df85 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
}
+static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
+ Module &M) {
+ if (List.empty())
+ return;
+
+ Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
+
+ // Convert List to what ConstantArray needs.
+ SmallVector<Constant *, 8> UsedArray;
+ UsedArray.reserve(List.size());
+ for (auto Item : List)
+ UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(&*Item), PtrTy));
+
+ ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
+ auto *GV =
+ new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
+ llvm::ConstantArray::get(ArrTy, UsedArray), Name);
+
+ GV->setSection("llvm.metadata");
+}
+
+static void
+emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ StringRef FunctionName, OMPTgtExecModeFlags Mode,
+ std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
+ auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
+ auto *GVMode = new llvm::GlobalVariable(
+ OMPBuilder.M, Int8Ty, /*isConstant=*/true,
+ llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
+ Twine(FunctionName, "_exec_mode"));
+ GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
+ LLVMCompilerUsed.emplace_back(GVMode);
+}
+
static Expected<Function *> createOutlinedFunction(
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,27 @@ static Expected<Function *> createOutlinedFunction(
auto Func =
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
+ // Forward target-cpu and target-features function attributes from the
+ // original function to the new outlined function.
+ Function *ParentFn = Builder.GetInsertBlock()->getParent();
+
+ auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu");
+ if (TargetCpuAttr.isStringAttribute())
+ Func->addFnAttr(TargetCpuAttr);
+
+ auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features");
+ if (TargetFeaturesAttr.isStringAttribute())
+ Func->addFnAttr(TargetFeaturesAttr);
+
+ if (OMPBuilder.Config.isTargetDevice()) {
+ std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
+ emitExecutionMode(OMPBuilder, Builder, FuncName,
+ IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
+ : OMP_TGT_EXEC_MODE_GENERIC,
+ LLVMCompilerUsed);
+ emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+ }
+
// Save insert point.
IRBuilder<>::InsertPointGuard IPG(Builder);
// If there's a DISubprogram associated with current function, then
@@ -6798,7 +6854,7 @@ static Expected<Function *> createOutlinedFunction(
// Insert target init call in the device compilation pass.
if (OMPBuilder.Config.isTargetDevice())
Builder.restoreIP(
- OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
+ OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
@@ -6995,7 +7051,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
static Error emitTargetOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
- TargetRegionEntryInfo &EntryInfo,
+ bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
Function *&OutlinedFn, Constant *&OutlinedFnID,
SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7060,7 @@ static Error emitTargetOutlinedFunction(
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
[&](StringRef EntryFnName) {
- return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+ return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
EntryFnName, Inputs, CBFunc,
ArgAccessorFuncCB);
};
@@ -7304,6 +7360,7 @@ static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7442,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
/*ForEndCall=*/false);
SmallVector<Value *, 3> NumTeamsC;
+ for (auto [DefaultVal, RuntimeVal] :
+ zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+ NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+ if (Clause)
+ Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+ /*isSigned=*/false);
+ return Clause;
+ };
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+ if (Clause)
+ Result = Result
+ ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+ Result, Clause)
+ : Clause;
+ };
+
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+ // the NUM_THREADS clause is overridden by THREAD_LIMIT.
SmallVector<Value *, 3> NumThreadsC;
- for (auto V : DefaultAttrs.MaxTeams)
- NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : DefaultAttrs.MaxThreads)
- NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+ ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+ : nullptr;
+
+ for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
+ RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+ Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+ CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+ CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+ NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+ }
unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
@@ -7398,14 +7487,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
- // TODO: Use correct NumIterations
- Value *NumIterations = Builder.getInt64(0);
+
+ Value *TripCount = RuntimeAttrs.LoopTripCount
+ ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+ Builder.getInt64Ty(),
+ /*isSigned=*/false)
+ : Builder.getInt64(0);
+
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
- KArgs = OpenMPIRBuilder::TargetKernelArgs(
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+ NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
@@ -7427,13 +7521,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
- InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+ InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+ TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
+ assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
+ "trip count not expected if IsSPMD=false");
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -7446,16 +7544,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB
if (Error Err = emitTargetOutlinedFunction(
- *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
+ *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+ OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
return Err;
// If we are not on the target device, then we need to generate code
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+ emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+ HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b0688d6215e42d..63be7e775b83c9 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6122,8 +6122,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
OpenMPIRBuilderConfig Config(false, false, false, false, false, false, false);
OMPBuilder.setConfig(Config);
F->setName("func");
+ F->addFnAttr("target-cpu", "x86-64");
+ F->addFnAttr("target-features", "+mmx,+sse");
IRBuilder<> Builder(BB);
- auto Int32Ty = Builder.getInt32Ty();
+ auto *Int32Ty = Builder.getInt32Ty();
AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6183,11 +6185,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
- /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
- Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+ RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+ RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
OMPBuilder.finalize();
@@ -6207,6 +6213,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
StringRef FunctionName = KernelLaunchFunc->getName();
EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+ // Check num_teams and num_threads in call arguments
+ EXPECT_TRUE(Call->arg_size() >= 4);
+ Value *NumTeamsArg = Call->getArgOperand(2);
+ EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+ EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+ Value *NumThreadsArg = Call->getArgOperand(3);
+ EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+ EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+ // Check num_teams and num_threads kernel arguments (use number 5 starting
+ // from the end and counting the call to __tgt_target_kernel as the first use)
+ Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+ EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+ Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+ Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+ Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+ auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+ EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+ EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+ Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+ Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+ Value *NumThreadsStoreArg =
+ cast<StoreInst>(NumThreadsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+ auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+ EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+ EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
// Check the fallback call
BasicBlock *FallbackBlock = Branch->getSuccessor(0);
Iter = FallbackBlock->rbegin();
@@ -6228,6 +6271,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
StringRef FunctionName2 = OutlinedFunc->getName();
EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
+ // Check that target-cpu and target-features were propagated to the outlined
+ // function
+ EXPECT_EQ(OutlinedFunc->getFnAttribute("target-cpu"),
+ F->getFnAttribute("target-cpu"));
+ EXPECT_EQ(OutlinedFunc->getFnAttribute("target-features"),
+ F->getFnAttribute("target-features"));
+
EXPECT_FALSE(verifyModule(*M, &errs()));
}
@@ -6238,6 +6288,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
OMPBuilder.initialize();
F->setName("func");
+ F->addFnAttr("target-cpu", "gfx90a");
+ F->addFnAttr("target-features", "+gfx9-insts,+wavefrontsize64");
IRBuilder<> Builder(BB);
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
@@ -6297,9 +6349,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
- CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -6312,6 +6366,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
Function *OutlinedFn = TargetStore->getFunction();
EXPECT_NE(F, OutlinedFn);
+ // Check that target-cpu and target-features were propagated to the outlined
+ // function
+ EXPECT_EQ(OutlinedFn->getFnAttribute("target-cpu"),
+ ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/116051
More information about the llvm-branch-commits
mailing list