[llvm] [mlir] [llvm][mlir][OMPIRBuilder] Translate omp.single's copyprivate (PR #80488)
Leandro Lupori via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 22 12:17:16 PST 2024
================
@@ -3464,6 +3464,117 @@ TEST_F(OpenMPIRBuilderTest, SingleDirectiveNowait) {
EXPECT_EQ(ExitBarrier, nullptr);
}
+TEST_F(OpenMPIRBuilderTest, SingleDirectiveCopyPrivate) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+ AllocaInst *PrivAI = nullptr;
+
+ BasicBlock *EntryBB = nullptr;
+ BasicBlock *ThenBB = nullptr;
+
+ Value *CPVar = Builder.CreateAlloca(F->arg_begin()->getType());
+ Builder.CreateStore(F->arg_begin(), CPVar);
+
+ FunctionType *CopyFuncTy = FunctionType::get(
+ Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getPtrTy()}, false);
+ Function *CopyFunc =
+ Function::Create(CopyFuncTy, Function::PrivateLinkage, "copy_var", *M);
+
+ Value *DidIt = Builder.CreateAlloca(Type::getInt32Ty(Builder.getContext()));
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+ if (AllocaIP.isSet())
+ Builder.restoreIP(AllocaIP);
+ else
+ Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
+ PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
+ Builder.CreateStore(F->arg_begin(), PrivAI);
+
+ llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
+ llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
+ EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
+
+ Builder.restoreIP(CodeGenIP);
+
+ // collect some info for checks later
+ ThenBB = Builder.GetInsertBlock();
+ EntryBB = ThenBB->getUniquePredecessor();
+
+ // simple instructions for body
+ Value *PrivLoad =
+ Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
+ Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
+ };
+
+ auto FiniCB = [&](InsertPointTy IP) {
+ BasicBlock *IPBB = IP.getBlock();
+ EXPECT_NE(IPBB->end(), IP.getPoint());
+ };
+
+ Builder.restoreIP(OMPBuilder.createSingle(Builder, BodyGenCB, FiniCB,
+ /*IsNowait*/ false, DidIt, {CPVar},
+ {CopyFunc}));
+ Value *EntryBBTI = EntryBB->getTerminator();
+ EXPECT_NE(EntryBBTI, nullptr);
+ EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
+ BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
+ EXPECT_TRUE(EntryBr->isConditional());
+ EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
+ BasicBlock *ExitBB = ThenBB->getUniqueSuccessor();
+ EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
+
+ CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
+ EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
+
+ CallInst *SingleEntryCI = cast<CallInst>(CondInst->getOperand(0));
+ EXPECT_EQ(SingleEntryCI->arg_size(), 2U);
+ EXPECT_EQ(SingleEntryCI->getCalledFunction()->getName(), "__kmpc_single");
+ EXPECT_TRUE(isa<GlobalVariable>(SingleEntryCI->getArgOperand(0)));
+
+ CallInst *SingleEndCI = nullptr;
+ for (auto &FI : *ThenBB) {
+ Instruction *Cur = &FI;
+ if (isa<CallInst>(Cur)) {
+ SingleEndCI = cast<CallInst>(Cur);
+ if (SingleEndCI->getCalledFunction()->getName() == "__kmpc_end_single")
+ break;
+ SingleEndCI = nullptr;
+ }
+ }
----------------
luporl wrote:
The unit test now checks every instruction. That may be a bit excessive, but it shouldn't cause any problems.
https://github.com/llvm/llvm-project/pull/80488
More information about the llvm-commits mailing list