[llvm] 000c6a5 - [OpenMP] Use the OpenMPIRBuilder for `omp cancel`

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 30 11:58:06 PST 2019


Author: Johannes Doerfert
Date: 2019-12-30T13:57:13-06:00
New Revision: 000c6a5038bc654946b4348e586d685077b06943

URL: https://github.com/llvm/llvm-project/commit/000c6a5038bc654946b4348e586d685077b06943
DIFF: https://github.com/llvm/llvm-project/commit/000c6a5038bc654946b4348e586d685077b06943.diff

LOG: [OpenMP] Use the OpenMPIRBuilder for `omp cancel`

An `omp cancel parallel` needs to be emitted by the OpenMPIRBuilder if
the `parallel` was emitted by the OpenMPIRBuilder. This patch makes
this possible by adding `OpenMPIRBuilder::CreateCancel`. The cancel
logic is shared with the cancel barriers. Testing is done via unit
tests and, once D70290 lands, via the clang cancel_codegen.cpp test.

Reviewed By: JonChesterfield

Differential Revision: https://reviews.llvm.org/D71948

Added: 
    

Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 880add6b9bba..6568bd89068c 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -139,6 +139,17 @@ class OpenMPIRBuilder {
                               bool ForceSimpleCall = false,
                               bool CheckCancelFlag = true);
 
+  /// Generator for '#omp cancel'
+  ///
+  /// \param Loc The location where the directive was encountered.
+  /// \param IfCondition The evaluated 'if' clause expression, or nullptr.
+  /// \param CanceledDirective The kind of directive that is canceled.
+  ///
+  /// \returns The insertion point after the cancellation check.
+  InsertPointTy CreateCancel(const LocationDescription &Loc,
+                              Value *IfCondition,
+                              omp::Directive CanceledDirective);
+
   /// Generator for '#omp parallel'
   ///
   /// \param Loc The insert and source location description.
@@ -183,6 +194,13 @@ class OpenMPIRBuilder {
   Value *getOrCreateIdent(Constant *SrcLocStr,
                           omp::IdentFlag Flags = omp::IdentFlag(0));
 
+  /// Generate control flow and cleanup for cancellation.
+  ///
+  /// \param CancelFlag Runtime result; non-zero selects the cancel path.
+  /// \param CanceledDirective The kind of directive that is canceled.
+  void emitCancelationCheckImpl(Value *CancelFlag,
+                                omp::Directive CanceledDirective);
+
   /// Generate a barrier runtime call.
   ///
   /// \param Loc The location at which the request originated and is fulfilled.

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index 51644542848a..80707321f82f 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -164,11 +164,13 @@ __OMP_FUNCTION_TYPE(ParallelTask, true, Void, Int32Ptr, Int32Ptr)
   OMP_RTL(OMPRTL_##Name, #Name, IsVarArg, ReturnType, __VA_ARGS__)
 
 __OMP_RTL(__kmpc_barrier, false, Void, IdentPtr, Int32)
+__OMP_RTL(__kmpc_cancel, false, Int32, IdentPtr, Int32, /* CancelKind */ Int32)
 __OMP_RTL(__kmpc_cancel_barrier, false, Int32, IdentPtr, Int32)
 __OMP_RTL(__kmpc_global_thread_num, false, Int32, IdentPtr)
 __OMP_RTL(__kmpc_fork_call, true, Void, IdentPtr, Int32, ParallelTaskPtr)
-__OMP_RTL(__kmpc_push_num_threads, false, Void, IdentPtr, Int32, /* Int */Int32)
-__OMP_RTL(__kmpc_push_proc_bind, false, Void, IdentPtr, Int32, /* Int */Int32)
+__OMP_RTL(__kmpc_push_num_threads, false, Void, IdentPtr, Int32,
+          /* Int */ Int32)
+__OMP_RTL(__kmpc_push_proc_bind, false, Void, IdentPtr, Int32, /* Int */ Int32)
 __OMP_RTL(__kmpc_serialized_parallel, false, Void, IdentPtr, Int32)
 __OMP_RTL(__kmpc_end_serialized_parallel, false, Void, IdentPtr, Int32)
 
@@ -240,6 +242,26 @@ __OMP_IDENT_FLAG(BARRIER_IMPL_WORKSHARE, 0x01C0)
 
 ///}
 
+/// KMP cancel kinds: values passed as the kind argument of __kmpc_cancel
+///
+///{
+
+#ifndef OMP_CANCEL_KIND
+#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)
+#endif
+
+#define __OMP_CANCEL_KIND(Name, Value)                                         \
+  OMP_CANCEL_KIND(OMP_CANCEL_KIND_##Name, #Name, OMPD_##Name, Value)
+
+__OMP_CANCEL_KIND(parallel, 1)
+__OMP_CANCEL_KIND(for, 2)
+__OMP_CANCEL_KIND(sections, 3)
+__OMP_CANCEL_KIND(taskgroup, 4)
+
+#undef __OMP_CANCEL_KIND
+#undef OMP_CANCEL_KIND
+
+///}
 
 /// Proc bind kinds
 ///

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 9597eaa3f3c6..57db40775c05 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -216,41 +216,90 @@ OpenMPIRBuilder::emitBarrierImpl(const LocationDescription &Loc, Directive Kind,
                                                   : OMPRTL___kmpc_barrier),
       Args);
 
-  if (UseCancelBarrier && CheckCancelFlag) {
-    // For a cancel barrier we create two new blocks.
-    BasicBlock *BB = Builder.GetInsertBlock();
-    BasicBlock *NonCancellationBlock;
-    if (Builder.GetInsertPoint() == BB->end()) {
-      // TODO: This branch will not be needed once we moved to the
-      // OpenMPIRBuilder codegen completely.
-      NonCancellationBlock = BasicBlock::Create(
-          BB->getContext(), BB->getName() + ".cont", BB->getParent());
-    } else {
-      NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
-      BB->getTerminator()->eraseFromParent();
-      Builder.SetInsertPoint(BB);
-    }
-    BasicBlock *CancellationBlock = BasicBlock::Create(
-        BB->getContext(), BB->getName() + ".cncl", BB->getParent());
-
-    // Jump to them based on the return value.
-    Value *Cmp = Builder.CreateIsNull(Result);
-    Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
-                         /* TODO weight */ nullptr, nullptr);
-
-    // From the cancellation block we finalize all variables and go to the
-    // post finalization block that is known to the FiniCB callback.
-    Builder.SetInsertPoint(CancellationBlock);
-    auto &FI = FinalizationStack.back();
-    FI.FiniCB(Builder.saveIP());
-
-    // The continuation block is where code generation continues.
-    Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
+  if (UseCancelBarrier && CheckCancelFlag)
+    emitCancelationCheckImpl(Result, OMPD_parallel);
+
+  return Builder.saveIP();
+}
+
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::CreateCancel(const LocationDescription &Loc,
+                              Value *IfCondition,
+                              omp::Directive CanceledDirective) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  // Temporary terminator (LLVM block utilities expect one); erased below.
+  auto *UI = Builder.CreateUnreachable();
+
+  Instruction *ThenTI = UI, *ElseTI = nullptr;
+  if (IfCondition)
+    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
+  Builder.SetInsertPoint(ThenTI); // Without an 'if' clause ThenTI is UI.
+
+  Value *CancelKind = nullptr;
+  switch (CanceledDirective) {
+#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
+  case DirectiveEnum:                                                          \
+    CancelKind = Builder.getInt32(Value);                                      \
+    break;
+#include "llvm/Frontend/OpenMP/OMPKinds.def"
+  default:
+    llvm_unreachable("Unknown cancel kind!");
   }
 
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
+  Value *Ident = getOrCreateIdent(SrcLocStr);
+  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
+  Value *Result = Builder.CreateCall(
+      getOrCreateRuntimeFunction(OMPRTL___kmpc_cancel), Args);
+
+  // The actual cancel logic is shared with others, e.g., cancel_barriers.
+  emitCancelationCheckImpl(Result, CanceledDirective);
+
+  // Update the insertion point and remove the terminator we introduced.
+  Builder.SetInsertPoint(UI->getParent()); // Block end, not where UI was.
+  UI->eraseFromParent(); // NOTE(review): assumes Loc.IP was a block end.
 
   return Builder.saveIP();
 }
 
+void OpenMPIRBuilder::emitCancelationCheckImpl(
+    Value *CancelFlag, omp::Directive CanceledDirective) {
+  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
+         "Unexpected cancellation!");
+
+  // Set up the continuation and cancellation (.cncl) blocks.
+  BasicBlock *BB = Builder.GetInsertBlock();
+  BasicBlock *NonCancellationBlock;
+  if (Builder.GetInsertPoint() == BB->end()) {
+    // TODO: This branch will not be needed once we moved to the
+    // OpenMPIRBuilder codegen completely.
+    NonCancellationBlock = BasicBlock::Create(
+        BB->getContext(), BB->getName() + ".cont", BB->getParent());
+  } else {
+    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
+    BB->getTerminator()->eraseFromParent(); // Replaced by the cond br below.
+    Builder.SetInsertPoint(BB);
+  }
+  BasicBlock *CancellationBlock = BasicBlock::Create(
+      BB->getContext(), BB->getName() + ".cncl", BB->getParent());
+
+  // Branch on the flag: zero continues, non-zero takes the cancel path.
+  Value *Cmp = Builder.CreateIsNull(CancelFlag);
+  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
+                       /* TODO weight */ nullptr, nullptr);
+
+  // From the cancellation block we finalize all variables and go to the
+  // post finalization block that is known to the FiniCB callback.
+  Builder.SetInsertPoint(CancellationBlock);
+  auto &FI = FinalizationStack.back(); // Most recently pushed entry.
+  FI.FiniCB(Builder.saveIP());
+
+  // The continuation block is where code generation continues.
+  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
+}
+
 IRBuilder<>::InsertPoint OpenMPIRBuilder::CreateParallel(
     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
     PrivatizeCallbackTy PrivCB, FinalizeCallbackTy FiniCB, Value *IfCondition,

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 7d2d0b9fe0a2..e777149b30ee 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -99,6 +99,122 @@ TEST_F(OpenMPIRBuilderTest, CreateBarrier) {
   EXPECT_FALSE(verifyModule(*M));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateCancel) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+
+  BasicBlock *CBB = BasicBlock::Create(Ctx, "", F);
+  new UnreachableInst(Ctx, CBB); // CBB: where FiniCB branches to.
+  auto FiniCB = [&](InsertPointTy IP) {
+    ASSERT_NE(IP.getBlock(), nullptr);
+    ASSERT_EQ(IP.getBlock()->end(), IP.getPoint());
+    BranchInst::Create(CBB, IP.getBlock());
+  };
+  OMPBuilder.pushFinalizationCB({FiniCB, OMPD_parallel, true});
+
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()});
+  auto NewIP = OMPBuilder.CreateCancel(Loc, nullptr, OMPD_parallel);
+  Builder.restoreIP(NewIP);
+  EXPECT_FALSE(M->global_empty());
+  EXPECT_EQ(M->size(), 3U); // F, __kmpc_global_thread_num, __kmpc_cancel.
+  EXPECT_EQ(F->size(), 4U); // BB, CBB, continuation and .cncl blocks.
+  EXPECT_EQ(BB->size(), 4U); // GTID call, cancel call, icmp, cond br.
+
+  CallInst *GTID = dyn_cast<CallInst>(&BB->front());
+  EXPECT_NE(GTID, nullptr);
+  EXPECT_EQ(GTID->getNumArgOperands(), 1U);
+  EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
+  EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
+  EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
+
+  CallInst *Cancel = dyn_cast<CallInst>(GTID->getNextNode());
+  EXPECT_NE(Cancel, nullptr);
+  EXPECT_EQ(Cancel->getNumArgOperands(), 3U);
+  EXPECT_EQ(Cancel->getCalledFunction()->getName(), "__kmpc_cancel");
+  EXPECT_FALSE(Cancel->getCalledFunction()->doesNotAccessMemory());
+  EXPECT_FALSE(Cancel->getCalledFunction()->doesNotFreeMemory());
+  EXPECT_EQ(Cancel->getNumUses(), 1U);
+  Instruction *CancelBBTI = Cancel->getParent()->getTerminator();
+  EXPECT_EQ(CancelBBTI->getNumSuccessors(), 2U);
+  EXPECT_EQ(CancelBBTI->getSuccessor(0), NewIP.getBlock());
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 1U); // FiniCB's branch.
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(),
+            1U);
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0),
+            CBB);
+
+  EXPECT_EQ(cast<CallInst>(Cancel)->getArgOperand(1), GTID);
+
+  OMPBuilder.popFinalizationCB();
+
+  Builder.CreateUnreachable();
+  EXPECT_FALSE(verifyModule(*M));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateCancelIfCond) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+
+  BasicBlock *CBB = BasicBlock::Create(Ctx, "", F);
+  new UnreachableInst(Ctx, CBB); // CBB: where FiniCB branches to.
+  auto FiniCB = [&](InsertPointTy IP) {
+    ASSERT_NE(IP.getBlock(), nullptr);
+    ASSERT_EQ(IP.getBlock()->end(), IP.getPoint());
+    BranchInst::Create(CBB, IP.getBlock());
+  };
+  OMPBuilder.pushFinalizationCB({FiniCB, OMPD_parallel, true});
+
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()});
+  auto NewIP = OMPBuilder.CreateCancel(Loc, Builder.getTrue(), OMPD_parallel);
+  Builder.restoreIP(NewIP);
+  EXPECT_FALSE(M->global_empty());
+  EXPECT_EQ(M->size(), 3U); // F, __kmpc_global_thread_num, __kmpc_cancel.
+  EXPECT_EQ(F->size(), 7U); // BB, CBB, then, else, tail, cont, .cncl.
+  EXPECT_EQ(BB->size(), 1U); // Only the branch on IfCondition.
+  ASSERT_TRUE(isa<BranchInst>(BB->getTerminator()));
+  ASSERT_EQ(BB->getTerminator()->getNumSuccessors(), 2U);
+  BB = BB->getTerminator()->getSuccessor(0); // The 'then' block.
+  EXPECT_EQ(BB->size(), 4U); // GTID call, cancel call, icmp, cond br.
+
+
+  CallInst *GTID = dyn_cast<CallInst>(&BB->front());
+  EXPECT_NE(GTID, nullptr);
+  EXPECT_EQ(GTID->getNumArgOperands(), 1U);
+  EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
+  EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
+  EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
+
+  CallInst *Cancel = dyn_cast<CallInst>(GTID->getNextNode());
+  EXPECT_NE(Cancel, nullptr);
+  EXPECT_EQ(Cancel->getNumArgOperands(), 3U);
+  EXPECT_EQ(Cancel->getCalledFunction()->getName(), "__kmpc_cancel");
+  EXPECT_FALSE(Cancel->getCalledFunction()->doesNotAccessMemory());
+  EXPECT_FALSE(Cancel->getCalledFunction()->doesNotFreeMemory());
+  EXPECT_EQ(Cancel->getNumUses(), 1U);
+  Instruction *CancelBBTI = Cancel->getParent()->getTerminator();
+  EXPECT_EQ(CancelBBTI->getNumSuccessors(), 2U);
+  EXPECT_EQ(CancelBBTI->getSuccessor(0)->size(), 1U); // br to the tail.
+  EXPECT_EQ(CancelBBTI->getSuccessor(0)->getUniqueSuccessor(), NewIP.getBlock());
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 1U); // FiniCB's branch.
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(),
+            1U);
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0),
+            CBB);
+
+  EXPECT_EQ(cast<CallInst>(Cancel)->getArgOperand(1), GTID);
+
+  OMPBuilder.popFinalizationCB();
+
+  Builder.CreateUnreachable();
+  EXPECT_FALSE(verifyModule(*M));
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateCancelBarrier) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);


        


More information about the llvm-commits mailing list