[Mlir-commits] [llvm] [mlir] [llvm][mlir][OMPIRBuilder] Translate omp.single's copyprivate (PR #80488)

Kareem Ergawy llvmlistbot at llvm.org
Wed Feb 21 06:06:44 PST 2024


================
@@ -3464,6 +3464,117 @@ TEST_F(OpenMPIRBuilderTest, SingleDirectiveNowait) {
   EXPECT_EQ(ExitBarrier, nullptr);
 }
 
+TEST_F(OpenMPIRBuilderTest, SingleDirectiveCopyPrivate) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  AllocaInst *PrivAI = nullptr;
+
+  BasicBlock *EntryBB = nullptr;
+  BasicBlock *ThenBB = nullptr;
+
+  Value *CPVar = Builder.CreateAlloca(F->arg_begin()->getType());
+  Builder.CreateStore(F->arg_begin(), CPVar);
+
+  FunctionType *CopyFuncTy = FunctionType::get(
+      Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getPtrTy()}, false);
+  Function *CopyFunc =
+      Function::Create(CopyFuncTy, Function::PrivateLinkage, "copy_var", *M);
+
+  Value *DidIt = Builder.CreateAlloca(Type::getInt32Ty(Builder.getContext()));
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    if (AllocaIP.isSet())
+      Builder.restoreIP(AllocaIP);
+    else
+      Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
+    PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
+    Builder.CreateStore(F->arg_begin(), PrivAI);
+
+    llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
+    llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
+    EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
+
+    Builder.restoreIP(CodeGenIP);
+
+    // collect some info for checks later
+    ThenBB = Builder.GetInsertBlock();
+    EntryBB = ThenBB->getUniquePredecessor();
+
+    // simple instructions for body
+    Value *PrivLoad =
+        Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
+    Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
+  };
+
+  auto FiniCB = [&](InsertPointTy IP) {
+    BasicBlock *IPBB = IP.getBlock();
+    EXPECT_NE(IPBB->end(), IP.getPoint());
+  };
+
+  Builder.restoreIP(OMPBuilder.createSingle(Builder, BodyGenCB, FiniCB,
+                                            /*IsNowait*/ false, DidIt, {CPVar},
+                                            {CopyFunc}));
+  Value *EntryBBTI = EntryBB->getTerminator();
+  EXPECT_NE(EntryBBTI, nullptr);
+  EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
+  BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
+  EXPECT_TRUE(EntryBr->isConditional());
+  EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
+  BasicBlock *ExitBB = ThenBB->getUniqueSuccessor();
+  EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
+
+  CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
+  EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
+
+  CallInst *SingleEntryCI = cast<CallInst>(CondInst->getOperand(0));
+  EXPECT_EQ(SingleEntryCI->arg_size(), 2U);
+  EXPECT_EQ(SingleEntryCI->getCalledFunction()->getName(), "__kmpc_single");
+  EXPECT_TRUE(isa<GlobalVariable>(SingleEntryCI->getArgOperand(0)));
+
+  CallInst *SingleEndCI = nullptr;
+  for (auto &FI : *ThenBB) {
+    Instruction *Cur = &FI;
+    if (isa<CallInst>(Cur)) {
+      SingleEndCI = cast<CallInst>(Cur);
+      if (SingleEndCI->getCalledFunction()->getName() == "__kmpc_end_single")
+        break;
+      SingleEndCI = nullptr;
+    }
+  }
+  EXPECT_NE(SingleEndCI, nullptr);
+  EXPECT_EQ(SingleEndCI->arg_size(), 2U);
+  EXPECT_TRUE(isa<GlobalVariable>(SingleEndCI->getArgOperand(0)));
+  EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
+
+  CallInst *CopyPrivateCI = nullptr;
+  bool FoundBarrier = false;
+  for (auto &FI : *ExitBB) {
+    Instruction *Cur = &FI;
+    if (auto *CI = dyn_cast<CallInst>(Cur)) {
+      if (CI->getCalledFunction()->getName() == "__kmpc_barrier")
+        FoundBarrier = true;
+      else if (CI->getCalledFunction()->getName() == "__kmpc_copyprivate")
+        CopyPrivateCI = CI;
+    }
+  }
----------------
ergawy wrote:

Similar suggestion here: I think we can check the contents of the BB more strictly.

https://github.com/llvm/llvm-project/pull/80488


More information about the Mlir-commits mailing list