[Mlir-commits] [llvm] [mlir] [llvm][mlir][OMPIRBuilder] Translate omp.single's copyprivate (PR #80488)

Leandro Lupori llvmlistbot at llvm.org
Thu Feb 22 12:12:28 PST 2024


https://github.com/luporl updated https://github.com/llvm/llvm-project/pull/80488

>From 3c6163c21ff4d3c40d7b533a718d55d596347ce9 Mon Sep 17 00:00:00 2001
From: Leandro Lupori <leandro.lupori at linaro.org>
Date: Fri, 2 Feb 2024 17:16:34 -0300
Subject: [PATCH 1/2] [llvm][mlir][OMPIRBuilder] Translate omp.single's
 copyprivate

Use the new copyprivate list from omp.single to emit calls to
__kmpc_copyprivate, during the creation of the single operation
in OMPIRBuilder.

This is patch 4 of 4, to add support for COPYPRIVATE in Flang.
Original PR: https://github.com/llvm/llvm-project/pull/73128
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |   6 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  23 +++-
 .../Frontend/OpenMPIRBuilderTest.cpp          | 111 ++++++++++++++++++
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  20 +++-
 mlir/test/Target/LLVMIR/openmp-llvm.mlir      |  32 +++++
 5 files changed, 187 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 2288969ecc95c4..5469e4cecb6735 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1828,12 +1828,16 @@ class OpenMPIRBuilder {
   /// \param FiniCB Callback to finalize variable copies.
   /// \param IsNowait If false, a barrier is emitted.
   /// \param DidIt Local variable used as a flag to indicate 'single' thread
+  /// \param CPVars copyprivate variables.
+  /// \param CPFuncs copy functions to use for each copyprivate variable.
   ///
   /// \returns The insertion position *after* the single call.
   InsertPointTy createSingle(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB,
                              FinalizeCallbackTy FiniCB, bool IsNowait,
-                             llvm::Value *DidIt);
+                             llvm::Value *DidIt,
+                             ArrayRef<llvm::Value *> CPVars = {},
+                             ArrayRef<llvm::Function *> CPFuncs = {});
 
   /// Generator for '#omp master'
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 02b333e9ccd567..347bc838d89165 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4002,7 +4002,8 @@ OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
 
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
-    FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt) {
+    FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt,
+    ArrayRef<llvm::Value *> CPVars, ArrayRef<llvm::Function *> CPFuncs) {
 
   if (!updateToLocation(Loc))
     return Loc.IP;
@@ -4025,17 +4026,33 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
 
+  auto FiniCBWrapper = [&](InsertPointTy IP) {
+    FiniCB(IP);
+
+    if (DidIt)
+      Builder.CreateStore(Builder.getInt32(1), DidIt);
+  };
+
   // generates the following:
   // if (__kmpc_single()) {
   //		.... single region ...
   // 		__kmpc_end_single
   // }
+  // __kmpc_copyprivate
   // __kmpc_barrier
 
-  EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
+  EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCBWrapper,
                        /*Conditional*/ true,
                        /*hasFinalize*/ true);
-  if (!IsNowait)
+
+  if (DidIt) {
+    for (size_t I = 0, E = CPVars.size(); I < E; ++I)
+      // NOTE BufSize is currently unused, so just pass 0.
+      createCopyPrivate(LocationDescription(Builder.saveIP(), Loc.DL),
+                        /*BufSize=*/ConstantInt::get(Int64, 0), CPVars[I],
+                        CPFuncs[I], DidIt);
+    // NOTE __kmpc_copyprivate already inserts a barrier
+  } else if (!IsNowait)
     createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                   omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                   /* CheckCancelFlag */ false);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index e79d0bb2f65aea..0eb1039aa442ce 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -3464,6 +3464,117 @@ TEST_F(OpenMPIRBuilderTest, SingleDirectiveNowait) {
   EXPECT_EQ(ExitBarrier, nullptr);
 }
 
+TEST_F(OpenMPIRBuilderTest, SingleDirectiveCopyPrivate) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  AllocaInst *PrivAI = nullptr;
+
+  BasicBlock *EntryBB = nullptr;
+  BasicBlock *ThenBB = nullptr;
+
+  Value *CPVar = Builder.CreateAlloca(F->arg_begin()->getType());
+  Builder.CreateStore(F->arg_begin(), CPVar);
+
+  FunctionType *CopyFuncTy = FunctionType::get(
+      Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getPtrTy()}, false);
+  Function *CopyFunc =
+      Function::Create(CopyFuncTy, Function::PrivateLinkage, "copy_var", *M);
+
+  Value *DidIt = Builder.CreateAlloca(Type::getInt32Ty(Builder.getContext()));
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    if (AllocaIP.isSet())
+      Builder.restoreIP(AllocaIP);
+    else
+      Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
+    PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
+    Builder.CreateStore(F->arg_begin(), PrivAI);
+
+    llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
+    llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
+    EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
+
+    Builder.restoreIP(CodeGenIP);
+
+    // collect some info for checks later
+    ThenBB = Builder.GetInsertBlock();
+    EntryBB = ThenBB->getUniquePredecessor();
+
+    // simple instructions for body
+    Value *PrivLoad =
+        Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
+    Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
+  };
+
+  auto FiniCB = [&](InsertPointTy IP) {
+    BasicBlock *IPBB = IP.getBlock();
+    EXPECT_NE(IPBB->end(), IP.getPoint());
+  };
+
+  Builder.restoreIP(OMPBuilder.createSingle(Builder, BodyGenCB, FiniCB,
+                                            /*IsNowait*/ false, DidIt, {CPVar},
+                                            {CopyFunc}));
+  Value *EntryBBTI = EntryBB->getTerminator();
+  EXPECT_NE(EntryBBTI, nullptr);
+  EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
+  BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
+  EXPECT_TRUE(EntryBr->isConditional());
+  EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
+  BasicBlock *ExitBB = ThenBB->getUniqueSuccessor();
+  EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
+
+  CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
+  EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
+
+  CallInst *SingleEntryCI = cast<CallInst>(CondInst->getOperand(0));
+  EXPECT_EQ(SingleEntryCI->arg_size(), 2U);
+  EXPECT_EQ(SingleEntryCI->getCalledFunction()->getName(), "__kmpc_single");
+  EXPECT_TRUE(isa<GlobalVariable>(SingleEntryCI->getArgOperand(0)));
+
+  CallInst *SingleEndCI = nullptr;
+  for (auto &FI : *ThenBB) {
+    Instruction *Cur = &FI;
+    if (isa<CallInst>(Cur)) {
+      SingleEndCI = cast<CallInst>(Cur);
+      if (SingleEndCI->getCalledFunction()->getName() == "__kmpc_end_single")
+        break;
+      SingleEndCI = nullptr;
+    }
+  }
+  EXPECT_NE(SingleEndCI, nullptr);
+  EXPECT_EQ(SingleEndCI->arg_size(), 2U);
+  EXPECT_TRUE(isa<GlobalVariable>(SingleEndCI->getArgOperand(0)));
+  EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
+
+  CallInst *CopyPrivateCI = nullptr;
+  bool FoundBarrier = false;
+  for (auto &FI : *ExitBB) {
+    Instruction *Cur = &FI;
+    if (auto *CI = dyn_cast<CallInst>(Cur)) {
+      if (CI->getCalledFunction()->getName() == "__kmpc_barrier")
+        FoundBarrier = true;
+      else if (CI->getCalledFunction()->getName() == "__kmpc_copyprivate")
+        CopyPrivateCI = CI;
+    }
+  }
+  EXPECT_FALSE(FoundBarrier);
+  EXPECT_NE(CopyPrivateCI, nullptr);
+  EXPECT_EQ(CopyPrivateCI->arg_size(), 6U);
+  EXPECT_TRUE(isa<AllocaInst>(CopyPrivateCI->getArgOperand(3)));
+  EXPECT_EQ(CopyPrivateCI->getArgOperand(3), CPVar);
+  EXPECT_TRUE(isa<Function>(CopyPrivateCI->getArgOperand(4)));
+  EXPECT_EQ(CopyPrivateCI->getArgOperand(4), CopyFunc);
+  EXPECT_TRUE(isa<LoadInst>(CopyPrivateCI->getArgOperand(5)));
+  LoadInst *DidItLI = cast<LoadInst>(CopyPrivateCI->getArgOperand(5));
+  EXPECT_EQ(DidItLI->getOperand(0), DidIt);
+}
+
 TEST_F(OpenMPIRBuilderTest, OMPAtomicReadFlt) {
   OpenMPIRBuilder OMPBuilder(*M);
   OMPBuilder.initialize();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 78a2ad76a1e3b8..43b6155cb72a3b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -656,8 +656,26 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
                         moduleTranslation, bodyGenStatus);
   };
   auto finiCB = [&](InsertPointTy codeGenIP) {};
+
+  // Handle copyprivate
+  Operation::operand_range cpVars = singleOp.getCopyprivateVars();
+  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateFuncs();
+  llvm::SmallVector<llvm::Value *> llvmCPVars;
+  llvm::SmallVector<llvm::Function *> llvmCPFuncs;
+  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
+    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
+    auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
+        singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
+    llvmCPFuncs.push_back(
+        moduleTranslation.lookupFunction(llvmFuncOp.getName()));
+  }
+  llvm::Value *didIt = nullptr;
+  if (!llvmCPVars.empty())
+    didIt = builder.CreateAlloca(llvm::Type::getInt32Ty(builder.getContext()));
+
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSingle(
-      ompLoc, bodyCB, finiCB, singleOp.getNowait(), /*DidIt=*/nullptr));
+      ompLoc, bodyCB, finiCB, singleOp.getNowait(), didIt, llvmCPVars,
+      llvmCPFuncs));
   return bodyGenStatus;
 }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 036367b262f07d..49103e7e1429cd 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2186,6 +2186,38 @@ llvm.func @single_nowait(%x: i32, %y: i32, %zaddr: !llvm.ptr) {
 
 // -----
 
+llvm.func @copy_i32(!llvm.ptr, !llvm.ptr)
+llvm.func @copy_f32(!llvm.ptr, !llvm.ptr)
+
+// CHECK-LABEL: @single_copyprivate
+// CHECK-SAME: (ptr %[[ip:.*]], ptr %[[fp:.*]])
+llvm.func @single_copyprivate(%ip: !llvm.ptr, %fp: !llvm.ptr) {
+  // CHECK: call i32 @__kmpc_single
+  omp.single copyprivate(%ip -> @copy_i32 : !llvm.ptr, %fp -> @copy_f32 : !llvm.ptr) {
+    // CHECK: %[[i:.*]] = load i32, ptr %[[ip]]
+    %i = llvm.load %ip : !llvm.ptr -> i32
+    // CHECK: %[[i2:.*]] = add i32 %[[i]], %[[i]]
+    %i2 = llvm.add %i, %i : i32
+    // CHECK: store i32 %[[i2]], ptr %[[ip]]
+    llvm.store %i2, %ip : i32, !llvm.ptr
+    // CHECK: %[[f:.*]] = load float, ptr %[[fp]]
+    %f = llvm.load %fp : !llvm.ptr -> f32
+    // CHECK: %[[f2:.*]] = fadd float %[[f]], %[[f]]
+    %f2 = llvm.fadd %f, %f : f32
+    // CHECK: store float %[[f2]], ptr %[[fp]]
+    llvm.store %f2, %fp : f32, !llvm.ptr
+    // CHECK: call void @__kmpc_end_single
+    // CHECK: call void @__kmpc_copyprivate({{.*}}, ptr %[[ip]], ptr @copy_i32, {{.*}})
+    // CHECK: call void @__kmpc_copyprivate({{.*}}, ptr %[[fp]], ptr @copy_f32, {{.*}})
+    // CHECK-NOT: call void @__kmpc_barrier
+    omp.terminator
+  }
+  // CHECK: ret void
+  llvm.return
+}
+
+// -----
+
 // CHECK: @_QFsubEx = internal global i32 undef
 // CHECK: @_QFsubEx.cache = common global ptr null
 

>From e68238f9b5dec15d05a0e52c8b4d22d756804628 Mon Sep 17 00:00:00 2001
From: Leandro Lupori <leandro.lupori at linaro.org>
Date: Thu, 22 Feb 2024 19:23:20 +0000
Subject: [PATCH 2/2] Address review's comments

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  2 -
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 10 +-
 .../Frontend/OpenMPIRBuilderTest.cpp          | 96 +++++++++++++------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  6 +-
 4 files changed, 72 insertions(+), 42 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 5469e4cecb6735..bf0dae270f4320 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1827,7 +1827,6 @@ class OpenMPIRBuilder {
   /// \param BodyGenCB Callback that will generate the region code.
   /// \param FiniCB Callback to finalize variable copies.
   /// \param IsNowait If false, a barrier is emitted.
-  /// \param DidIt Local variable used as a flag to indicate 'single' thread
   /// \param CPVars copyprivate variables.
   /// \param CPFuncs copy functions to use for each copyprivate variable.
   ///
@@ -1835,7 +1834,6 @@ class OpenMPIRBuilder {
   InsertPointTy createSingle(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB,
                              FinalizeCallbackTy FiniCB, bool IsNowait,
-                             llvm::Value *DidIt,
                              ArrayRef<llvm::Value *> CPVars = {},
                              ArrayRef<llvm::Function *> CPFuncs = {});
 
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 347bc838d89165..258ba2f7abe34b 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4002,14 +4002,16 @@ OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
 
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
-    FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt,
-    ArrayRef<llvm::Value *> CPVars, ArrayRef<llvm::Function *> CPFuncs) {
+    FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
+    ArrayRef<llvm::Function *> CPFuncs) {
 
   if (!updateToLocation(Loc))
     return Loc.IP;
 
-  // If needed (i.e. not null), initialize `DidIt` with 0
-  if (DidIt) {
+  // If needed, allocate `DidIt` and initialize it to 0
+  llvm::Value *DidIt = nullptr;
+  if (!CPVars.empty()) {
+    DidIt = Builder.CreateAlloca(llvm::Type::getInt32Ty(Builder.getContext()));
     Builder.CreateStore(Builder.getInt32(0), DidIt);
   }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 0eb1039aa442ce..359613df0b25a0 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -3327,8 +3327,8 @@ TEST_F(OpenMPIRBuilderTest, SingleDirective) {
     EXPECT_NE(IPBB->end(), IP.getPoint());
   };
 
-  Builder.restoreIP(OMPBuilder.createSingle(
-      Builder, BodyGenCB, FiniCB, /*IsNowait*/ false, /*DidIt*/ nullptr));
+  Builder.restoreIP(
+      OMPBuilder.createSingle(Builder, BodyGenCB, FiniCB, /*IsNowait*/ false));
   Value *EntryBBTI = EntryBB->getTerminator();
   EXPECT_NE(EntryBBTI, nullptr);
   EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
@@ -3417,8 +3417,8 @@ TEST_F(OpenMPIRBuilderTest, SingleDirectiveNowait) {
     EXPECT_NE(IPBB->end(), IP.getPoint());
   };
 
-  Builder.restoreIP(OMPBuilder.createSingle(
-      Builder, BodyGenCB, FiniCB, /*IsNowait*/ true, /*DidIt*/ nullptr));
+  Builder.restoreIP(
+      OMPBuilder.createSingle(Builder, BodyGenCB, FiniCB, /*IsNowait*/ true));
   Value *EntryBBTI = EntryBB->getTerminator();
   EXPECT_NE(EntryBBTI, nullptr);
   EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
@@ -3464,6 +3464,26 @@ TEST_F(OpenMPIRBuilderTest, SingleDirectiveNowait) {
   EXPECT_EQ(ExitBarrier, nullptr);
 }
 
+// Test helper: walks the instructions of a BB in order, checking their type.
+class BBInstIter {
+  BasicBlock *BB;
+  BasicBlock::iterator BBI;
+
+public:
+  BBInstIter(BasicBlock *BB) : BB(BB), BBI(BB->begin()) {}
+
+  // True while there are instructions of BB that next() has not consumed.
+  bool hasNext() const { return BBI != BB->end(); }
+
+  // Consumes the next instruction and returns it as InstTy. Returns nullptr
+  // if BB is exhausted or the consumed instruction is not an InstTy.
+  template <typename InstTy> InstTy *next() {
+    if (!hasNext())
+      return nullptr;
+    return dyn_cast<InstTy>(&*BBI++);
+  }
+};
+
 TEST_F(OpenMPIRBuilderTest, SingleDirectiveCopyPrivate) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
@@ -3486,8 +3506,6 @@ TEST_F(OpenMPIRBuilderTest, SingleDirectiveCopyPrivate) {
   Function *CopyFunc =
       Function::Create(CopyFuncTy, Function::PrivateLinkage, "copy_var", *M);
 
-  Value *DidIt = Builder.CreateAlloca(Type::getInt32Ty(Builder.getContext()));
-
   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
     if (AllocaIP.isSet())
       Builder.restoreIP(AllocaIP);
@@ -3514,11 +3532,12 @@ TEST_F(OpenMPIRBuilderTest, SingleDirectiveCopyPrivate) {
 
   auto FiniCB = [&](InsertPointTy IP) {
     BasicBlock *IPBB = IP.getBlock();
+    // IP must be before the unconditional branch to ExitBB
     EXPECT_NE(IPBB->end(), IP.getPoint());
   };
 
   Builder.restoreIP(OMPBuilder.createSingle(Builder, BodyGenCB, FiniCB,
-                                            /*IsNowait*/ false, DidIt, {CPVar},
+                                            /*IsNowait*/ false, {CPVar},
                                             {CopyFunc}));
   Value *EntryBBTI = EntryBB->getTerminator();
   EXPECT_NE(EntryBBTI, nullptr);
@@ -3537,33 +3556,47 @@ TEST_F(OpenMPIRBuilderTest, SingleDirectiveCopyPrivate) {
   EXPECT_EQ(SingleEntryCI->getCalledFunction()->getName(), "__kmpc_single");
   EXPECT_TRUE(isa<GlobalVariable>(SingleEntryCI->getArgOperand(0)));
 
-  CallInst *SingleEndCI = nullptr;
-  for (auto &FI : *ThenBB) {
-    Instruction *Cur = &FI;
-    if (isa<CallInst>(Cur)) {
-      SingleEndCI = cast<CallInst>(Cur);
-      if (SingleEndCI->getCalledFunction()->getName() == "__kmpc_end_single")
-        break;
-      SingleEndCI = nullptr;
-    }
-  }
+  // check ThenBB
+  BBInstIter ThenBBI(ThenBB);
+  // load PrivAI
+  auto *PrivLI = ThenBBI.next<LoadInst>();
+  EXPECT_NE(PrivLI, nullptr);
+  EXPECT_EQ(PrivLI->getPointerOperand(), PrivAI);
+  // icmp
+  EXPECT_TRUE(ThenBBI.next<ICmpInst>());
+  // store 1, DidIt
+  auto *DidItSI = ThenBBI.next<StoreInst>();
+  EXPECT_NE(DidItSI, nullptr);
+  EXPECT_EQ(DidItSI->getValueOperand(),
+            ConstantInt::get(Type::getInt32Ty(Ctx), 1));
+  Value *DidIt = DidItSI->getPointerOperand();
+  // call __kmpc_end_single
+  auto *SingleEndCI = ThenBBI.next<CallInst>();
   EXPECT_NE(SingleEndCI, nullptr);
+  EXPECT_EQ(SingleEndCI->getCalledFunction()->getName(), "__kmpc_end_single");
   EXPECT_EQ(SingleEndCI->arg_size(), 2U);
   EXPECT_TRUE(isa<GlobalVariable>(SingleEndCI->getArgOperand(0)));
   EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
-
-  CallInst *CopyPrivateCI = nullptr;
-  bool FoundBarrier = false;
-  for (auto &FI : *ExitBB) {
-    Instruction *Cur = &FI;
-    if (auto *CI = dyn_cast<CallInst>(Cur)) {
-      if (CI->getCalledFunction()->getName() == "__kmpc_barrier")
-        FoundBarrier = true;
-      else if (CI->getCalledFunction()->getName() == "__kmpc_copyprivate")
-        CopyPrivateCI = CI;
-    }
-  }
-  EXPECT_FALSE(FoundBarrier);
+  // br ExitBB
+  auto *ExitBBBI = ThenBBI.next<BranchInst>();
+  EXPECT_NE(ExitBBBI, nullptr);
+  EXPECT_TRUE(ExitBBBI->isUnconditional());
+  EXPECT_EQ(ExitBBBI->getOperand(0), ExitBB);
+  EXPECT_FALSE(ThenBBI.hasNext());
+
+  // check ExitBB
+  BBInstIter ExitBBI(ExitBB);
+  // call __kmpc_global_thread_num
+  auto *ThreadNumCI = ExitBBI.next<CallInst>();
+  EXPECT_NE(ThreadNumCI, nullptr);
+  EXPECT_EQ(ThreadNumCI->getCalledFunction()->getName(),
+            "__kmpc_global_thread_num");
+  // load DidIt
+  auto *DidItLI = ExitBBI.next<LoadInst>();
+  EXPECT_NE(DidItLI, nullptr);
+  EXPECT_EQ(DidItLI->getPointerOperand(), DidIt);
+  // call __kmpc_copyprivate
+  auto *CopyPrivateCI = ExitBBI.next<CallInst>();
   EXPECT_NE(CopyPrivateCI, nullptr);
   EXPECT_EQ(CopyPrivateCI->arg_size(), 6U);
   EXPECT_TRUE(isa<AllocaInst>(CopyPrivateCI->getArgOperand(3)));
@@ -3571,8 +3604,9 @@ TEST_F(OpenMPIRBuilderTest, SingleDirectiveCopyPrivate) {
   EXPECT_TRUE(isa<Function>(CopyPrivateCI->getArgOperand(4)));
   EXPECT_EQ(CopyPrivateCI->getArgOperand(4), CopyFunc);
   EXPECT_TRUE(isa<LoadInst>(CopyPrivateCI->getArgOperand(5)));
-  LoadInst *DidItLI = cast<LoadInst>(CopyPrivateCI->getArgOperand(5));
+  DidItLI = cast<LoadInst>(CopyPrivateCI->getArgOperand(5));
   EXPECT_EQ(DidItLI->getOperand(0), DidIt);
+  EXPECT_FALSE(ExitBBI.hasNext());
 }
 
 TEST_F(OpenMPIRBuilderTest, OMPAtomicReadFlt) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 43b6155cb72a3b..4aeff47280dbc4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -669,13 +669,9 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
     llvmCPFuncs.push_back(
         moduleTranslation.lookupFunction(llvmFuncOp.getName()));
   }
-  llvm::Value *didIt = nullptr;
-  if (!llvmCPVars.empty())
-    didIt = builder.CreateAlloca(llvm::Type::getInt32Ty(builder.getContext()));
 
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSingle(
-      ompLoc, bodyCB, finiCB, singleOp.getNowait(), didIt, llvmCPVars,
-      llvmCPFuncs));
+      ompLoc, bodyCB, finiCB, singleOp.getNowait(), llvmCPVars, llvmCPFuncs));
   return bodyGenStatus;
 }
 



More information about the Mlir-commits mailing list