[llvm] 000c6a5 - [OpenMP] Use the OpenMPIRBuilder for `omp cancel`
Johannes Doerfert via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 30 11:58:06 PST 2019
Author: Johannes Doerfert
Date: 2019-12-30T13:57:13-06:00
New Revision: 000c6a5038bc654946b4348e586d685077b06943
URL: https://github.com/llvm/llvm-project/commit/000c6a5038bc654946b4348e586d685077b06943
DIFF: https://github.com/llvm/llvm-project/commit/000c6a5038bc654946b4348e586d685077b06943.diff
LOG: [OpenMP] Use the OpenMPIRBuilder for `omp cancel`
An `omp cancel parallel` needs to be emitted by the OpenMPIRBuilder if
the `parallel` was emitted by the OpenMPIRBuilder. This patch makes
this possible. The cancel logic is shared with the cancel barriers.
Testing is done via unit tests and the clang cancel_codegen.cpp file
once D70290 lands.
Reviewed By: JonChesterfield
Differential Revision: https://reviews.llvm.org/D71948
Added:
Modified:
llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 880add6b9bba..6568bd89068c 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -139,6 +139,17 @@ class OpenMPIRBuilder {
bool ForceSimpleCall = false,
bool CheckCancelFlag = true);
+ /// Generator for '#omp cancel'
+ ///
+ /// Emits a call to __kmpc_cancel for \p CanceledDirective followed by a
+ /// check of its result; a non-zero result branches to the finalization
+ /// callback on top of the finalization stack, which must belong to a
+ /// cancellable \p CanceledDirective region.
+ ///
+ /// \param Loc The location where the directive was encountered. If
+ /// updateToLocation(\p Loc) fails, nothing is emitted and \p Loc.IP is
+ /// returned.
+ /// \param IfCondition The evaluated 'if' clause expression, if any. If
+ /// non-null, the cancel call is only executed when it is true.
+ /// \param CanceledDirective The kind of directive that is canceled.
+ ///
+ /// \returns The insertion point after the cancellation check, i.e., where
+ /// code generation continues on the non-cancelled path.
+ InsertPointTy CreateCancel(const LocationDescription &Loc,
+ Value *IfCondition,
+ omp::Directive CanceledDirective);
+
/// Generator for '#omp parallel'
///
/// \param Loc The insert and source location description.
@@ -183,6 +194,13 @@ class OpenMPIRBuilder {
Value *getOrCreateIdent(Constant *SrcLocStr,
omp::IdentFlag Flags = omp::IdentFlag(0));
+ /// Generate control flow and cleanup for cancellation.
+ ///
+ /// Branches on \p CancelFlag at the current insertion point: a zero value
+ /// continues in a new non-cancellation block, a non-zero value jumps to a
+ /// new cancellation block in which the finalization callback on top of the
+ /// finalization stack is invoked. On return the builder is positioned at the
+ /// start of the non-cancellation block.
+ ///
+ /// \param CancelFlag Runtime result (e.g., of __kmpc_cancel or
+ /// __kmpc_cancel_barrier); non-zero means the cancellation is performed.
+ /// \param CanceledDirective The kind of directive that is canceled; the top
+ /// of the finalization stack must be cancellable for it.
+ void emitCancelationCheckImpl(Value *CancelFlag,
+ omp::Directive CanceledDirective);
+
/// Generate a barrier runtime call.
///
/// \param Loc The location at which the request originated and is fulfilled.
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index 51644542848a..80707321f82f 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -164,11 +164,13 @@ __OMP_FUNCTION_TYPE(ParallelTask, true, Void, Int32Ptr, Int32Ptr)
OMP_RTL(OMPRTL_##Name, #Name, IsVarArg, ReturnType, __VA_ARGS__)
__OMP_RTL(__kmpc_barrier, false, Void, IdentPtr, Int32)
+/// __kmpc_cancel(ident, gtid, cancel kind); the cancel kind is taken from the
+/// OMP_CANCEL_KIND table and a non-zero result is treated as "cancel" by
+/// OpenMPIRBuilder::emitCancelationCheckImpl.
+__OMP_RTL(__kmpc_cancel, false, Int32, IdentPtr, Int32, Int32)
__OMP_RTL(__kmpc_cancel_barrier, false, Int32, IdentPtr, Int32)
__OMP_RTL(__kmpc_global_thread_num, false, Int32, IdentPtr)
__OMP_RTL(__kmpc_fork_call, true, Void, IdentPtr, Int32, ParallelTaskPtr)
-__OMP_RTL(__kmpc_push_num_threads, false, Void, IdentPtr, Int32, /* Int */Int32)
-__OMP_RTL(__kmpc_push_proc_bind, false, Void, IdentPtr, Int32, /* Int */Int32)
+__OMP_RTL(__kmpc_push_num_threads, false, Void, IdentPtr, Int32,
+ /* Int */ Int32)
+__OMP_RTL(__kmpc_push_proc_bind, false, Void, IdentPtr, Int32, /* Int */ Int32)
__OMP_RTL(__kmpc_serialized_parallel, false, Void, IdentPtr, Int32)
__OMP_RTL(__kmpc_end_serialized_parallel, false, Void, IdentPtr, Int32)
@@ -240,6 +242,26 @@ __OMP_IDENT_FLAG(BARRIER_IMPL_WORKSHARE, 0x01C0)
///}
+/// KMP cancel kind
+///
+/// Maps each cancellable directive (OMPD_<Name>) to the integer cancel kind
+/// passed as the last argument of __kmpc_cancel. Includers define
+/// OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) before including this
+/// file; both macros are undefined again at the end of this section.
+/// NOTE(review): the values presumably mirror the OpenMP runtime's cancel
+/// kind encoding — keep them in sync with the runtime.
+///
+///{
+
+#ifndef OMP_CANCEL_KIND
+#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)
+#endif
+
+#define __OMP_CANCEL_KIND(Name, Value) \
+ OMP_CANCEL_KIND(OMP_CANCEL_KIND_##Name, #Name, OMPD_##Name, Value)
+
+__OMP_CANCEL_KIND(parallel, 1)
+__OMP_CANCEL_KIND(for, 2)
+__OMP_CANCEL_KIND(sections, 3)
+__OMP_CANCEL_KIND(taskgroup, 4)
+
+#undef __OMP_CANCEL_KIND
+#undef OMP_CANCEL_KIND
+
+///}
/// Proc bind kinds
///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 9597eaa3f3c6..57db40775c05 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -216,41 +216,90 @@ OpenMPIRBuilder::emitBarrierImpl(const LocationDescription &Loc, Directive Kind,
: OMPRTL___kmpc_barrier),
Args);
- if (UseCancelBarrier && CheckCancelFlag) {
- // For a cancel barrier we create two new blocks.
- BasicBlock *BB = Builder.GetInsertBlock();
- BasicBlock *NonCancellationBlock;
- if (Builder.GetInsertPoint() == BB->end()) {
- // TODO: This branch will not be needed once we moved to the
- // OpenMPIRBuilder codegen completely.
- NonCancellationBlock = BasicBlock::Create(
- BB->getContext(), BB->getName() + ".cont", BB->getParent());
- } else {
- NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
- BB->getTerminator()->eraseFromParent();
- Builder.SetInsertPoint(BB);
- }
- BasicBlock *CancellationBlock = BasicBlock::Create(
- BB->getContext(), BB->getName() + ".cncl", BB->getParent());
-
- // Jump to them based on the return value.
- Value *Cmp = Builder.CreateIsNull(Result);
- Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
- /* TODO weight */ nullptr, nullptr);
-
- // From the cancellation block we finalize all variables and go to the
- // post finalization block that is known to the FiniCB callback.
- Builder.SetInsertPoint(CancellationBlock);
- auto &FI = FinalizationStack.back();
- FI.FiniCB(Builder.saveIP());
-
- // The continuation block is where code generation continues.
- Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
+ if (UseCancelBarrier && CheckCancelFlag)
+ emitCancelationCheckImpl(Result, OMPD_parallel);
+
+ return Builder.saveIP();
+}
+
+/// Emit '#omp cancel' for \p CanceledDirective: call __kmpc_cancel (guarded by
+/// \p IfCondition if given) and branch to the cancellation path when the
+/// runtime returns non-zero. Returns where code generation continues.
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::CreateCancel(const LocationDescription &Loc,
+ Value *IfCondition,
+ omp::Directive CanceledDirective) {
+ if (!updateToLocation(Loc))
+ return Loc.IP;
+
+ // LLVM utilities (e.g., SplitBlockAndInsertIfThenElse, SplitBlock) expect
+ // blocks with terminators; add a temporary one that is erased at the end.
+ auto *UI = Builder.CreateUnreachable();
+
+ // Without an 'if' clause the cancel code goes right before UI; otherwise it
+ // goes into the 'then' block created by SplitBlockAndInsertIfThenElse.
+ Instruction *ThenTI = UI, *ElseTI = nullptr;
+ if (IfCondition)
+ SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
+ Builder.SetInsertPoint(ThenTI);
+
+ // Map the directive to the runtime cancel kind listed in OMPKinds.def.
+ Value *CancelKind = nullptr;
+ switch (CanceledDirective) {
+#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) \
+ case DirectiveEnum: \
+ CancelKind = Builder.getInt32(Value); \
+ break;
+#include "llvm/Frontend/OpenMP/OMPKinds.def"
+ default:
+ llvm_unreachable("Unknown cancel kind!");
}
+ Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
+ Value *Ident = getOrCreateIdent(SrcLocStr);
+ Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
+ Value *Result = Builder.CreateCall(
+ getOrCreateRuntimeFunction(OMPRTL___kmpc_cancel), Args);
+
+ // The actual cancel logic is shared with others, e.g., cancel_barriers.
+ emitCancelationCheckImpl(Result, CanceledDirective);
+
+ // Update the insertion point and remove the terminator we introduced. The
+ // block holding UI is where code generation continues: the non-cancellation
+ // block, or the if/else merge block when an 'if' clause was given.
+ Builder.SetInsertPoint(UI->getParent());
+ UI->eraseFromParent();
+
return Builder.saveIP();
}
+/// Branch on \p CancelFlag at the current insertion point. Zero continues in a
+/// non-cancellation block; non-zero jumps to a new cancellation block in which
+/// the finalization callback on top of FinalizationStack is emitted. The
+/// builder is left at the start of the non-cancellation block.
+void OpenMPIRBuilder::emitCancelationCheckImpl(
+ Value *CancelFlag, omp::Directive CanceledDirective) {
+ assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
+ "Unexpected cancellation!");
+
+ // Shared by cancel and cancel barriers: create two new blocks, one to
+ // continue in and one for the cancellation path.
+ BasicBlock *BB = Builder.GetInsertBlock();
+ BasicBlock *NonCancellationBlock;
+ if (Builder.GetInsertPoint() == BB->end()) {
+ // TODO: This branch will not be needed once we moved to the
+ // OpenMPIRBuilder codegen completely.
+ NonCancellationBlock = BasicBlock::Create(
+ BB->getContext(), BB->getName() + ".cont", BB->getParent());
+ } else {
+ // SplitBlock moves the insertion point and everything after it into the
+ // new block and ends BB with an unconditional branch to it; drop that
+ // branch so the conditional branch below can terminate BB.
+ NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
+ BB->getTerminator()->eraseFromParent();
+ Builder.SetInsertPoint(BB);
+ }
+ BasicBlock *CancellationBlock = BasicBlock::Create(
+ BB->getContext(), BB->getName() + ".cncl", BB->getParent());
+
+ // Jump to them based on the return value.
+ Value *Cmp = Builder.CreateIsNull(CancelFlag);
+ Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
+ /* TODO weight */ nullptr, nullptr);
+
+ // From the cancellation block we finalize all variables and go to the
+ // post finalization block that is known to the FiniCB callback.
+ Builder.SetInsertPoint(CancellationBlock);
+ auto &FI = FinalizationStack.back();
+ FI.FiniCB(Builder.saveIP());
+
+ // The continuation block is where code generation continues.
+ Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
+}
+
IRBuilder<>::InsertPoint OpenMPIRBuilder::CreateParallel(
const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
PrivatizeCallbackTy PrivCB, FinalizeCallbackTy FiniCB, Value *IfCondition,
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 7d2d0b9fe0a2..e777149b30ee 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -99,6 +99,122 @@ TEST_F(OpenMPIRBuilderTest, CreateBarrier) {
EXPECT_FALSE(verifyModule(*M));
}
+/// CreateCancel without an 'if' clause: the thread id and __kmpc_cancel calls
+/// are emitted into BB, which then branches either to the continuation block
+/// (the returned insertion point) or to a cancellation block whose only
+/// instruction is the branch to CBB emitted by the finalization callback.
+TEST_F(OpenMPIRBuilderTest, CreateCancel) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+
+ // Block the finalization callback branches to on cancellation.
+ BasicBlock *CBB = BasicBlock::Create(Ctx, "", F);
+ new UnreachableInst(Ctx, CBB);
+ auto FiniCB = [&](InsertPointTy IP) {
+ ASSERT_NE(IP.getBlock(), nullptr);
+ ASSERT_EQ(IP.getBlock()->end(), IP.getPoint());
+ BranchInst::Create(CBB, IP.getBlock());
+ };
+ OMPBuilder.pushFinalizationCB({FiniCB, OMPD_parallel, true});
+
+ IRBuilder<> Builder(BB);
+
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()});
+ auto NewIP = OMPBuilder.CreateCancel(Loc, nullptr, OMPD_parallel);
+ Builder.restoreIP(NewIP);
+ EXPECT_FALSE(M->global_empty());
+ EXPECT_EQ(M->size(), 3U);
+ EXPECT_EQ(F->size(), 4U);
+ // Fatal: the checks below dereference BB's instructions.
+ ASSERT_EQ(BB->size(), 4U);
+
+ // Null/shape checks are ASSERTs so a mismatch fails the test cleanly
+ // instead of dereferencing a null pointer.
+ CallInst *GTID = dyn_cast<CallInst>(&BB->front());
+ ASSERT_NE(GTID, nullptr);
+ EXPECT_EQ(GTID->getNumArgOperands(), 1U);
+ EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
+ EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
+ EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
+
+ CallInst *Cancel = dyn_cast_or_null<CallInst>(GTID->getNextNode());
+ ASSERT_NE(Cancel, nullptr);
+ EXPECT_EQ(Cancel->getNumArgOperands(), 3U);
+ EXPECT_EQ(Cancel->getCalledFunction()->getName(), "__kmpc_cancel");
+ EXPECT_FALSE(Cancel->getCalledFunction()->doesNotAccessMemory());
+ EXPECT_FALSE(Cancel->getCalledFunction()->doesNotFreeMemory());
+ EXPECT_EQ(Cancel->getNumUses(), 1U);
+ Instruction *CancelBBTI = Cancel->getParent()->getTerminator();
+ ASSERT_NE(CancelBBTI, nullptr);
+ ASSERT_EQ(CancelBBTI->getNumSuccessors(), 2U);
+ EXPECT_EQ(CancelBBTI->getSuccessor(0), NewIP.getBlock());
+ BasicBlock *CancellationBB = CancelBBTI->getSuccessor(1);
+ EXPECT_EQ(CancellationBB->size(), 1U);
+ Instruction *CancellationTI = CancellationBB->getTerminator();
+ ASSERT_NE(CancellationTI, nullptr);
+ ASSERT_EQ(CancellationTI->getNumSuccessors(), 1U);
+ EXPECT_EQ(CancellationTI->getSuccessor(0), CBB);
+
+ EXPECT_EQ(Cancel->getArgOperand(1), GTID);
+
+ OMPBuilder.popFinalizationCB();
+
+ Builder.CreateUnreachable();
+ EXPECT_FALSE(verifyModule(*M));
+}
+
+/// CreateCancel with an 'if' clause: BB ends in a branch on the condition and
+/// the cancel code lives in the 'then' block, which branches either to a
+/// continuation block (leading to the returned if/else merge block) or to a
+/// cancellation block that branches to CBB via the finalization callback.
+TEST_F(OpenMPIRBuilderTest, CreateCancelIfCond) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+
+ // Block the finalization callback branches to on cancellation.
+ BasicBlock *CBB = BasicBlock::Create(Ctx, "", F);
+ new UnreachableInst(Ctx, CBB);
+ auto FiniCB = [&](InsertPointTy IP) {
+ ASSERT_NE(IP.getBlock(), nullptr);
+ ASSERT_EQ(IP.getBlock()->end(), IP.getPoint());
+ BranchInst::Create(CBB, IP.getBlock());
+ };
+ OMPBuilder.pushFinalizationCB({FiniCB, OMPD_parallel, true});
+
+ IRBuilder<> Builder(BB);
+
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()});
+ auto NewIP = OMPBuilder.CreateCancel(Loc, Builder.getTrue(), OMPD_parallel);
+ Builder.restoreIP(NewIP);
+ EXPECT_FALSE(M->global_empty());
+ EXPECT_EQ(M->size(), 3U);
+ EXPECT_EQ(F->size(), 7U);
+ // Fatal: the checks below dereference BB's terminator and its successors.
+ ASSERT_EQ(BB->size(), 1U);
+ Instruction *EntryTI = BB->getTerminator();
+ ASSERT_NE(EntryTI, nullptr);
+ ASSERT_TRUE(isa<BranchInst>(EntryTI));
+ ASSERT_EQ(EntryTI->getNumSuccessors(), 2U);
+ // The 'then' successor holds the cancel call.
+ BB = EntryTI->getSuccessor(0);
+ ASSERT_EQ(BB->size(), 4U);
+
+ // Null/shape checks are ASSERTs so a mismatch fails the test cleanly
+ // instead of dereferencing a null pointer.
+ CallInst *GTID = dyn_cast<CallInst>(&BB->front());
+ ASSERT_NE(GTID, nullptr);
+ EXPECT_EQ(GTID->getNumArgOperands(), 1U);
+ EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
+ EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory());
+ EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory());
+
+ CallInst *Cancel = dyn_cast_or_null<CallInst>(GTID->getNextNode());
+ ASSERT_NE(Cancel, nullptr);
+ EXPECT_EQ(Cancel->getNumArgOperands(), 3U);
+ EXPECT_EQ(Cancel->getCalledFunction()->getName(), "__kmpc_cancel");
+ EXPECT_FALSE(Cancel->getCalledFunction()->doesNotAccessMemory());
+ EXPECT_FALSE(Cancel->getCalledFunction()->doesNotFreeMemory());
+ EXPECT_EQ(Cancel->getNumUses(), 1U);
+ Instruction *CancelBBTI = Cancel->getParent()->getTerminator();
+ ASSERT_NE(CancelBBTI, nullptr);
+ ASSERT_EQ(CancelBBTI->getNumSuccessors(), 2U);
+ BasicBlock *ContinuationBB = CancelBBTI->getSuccessor(0);
+ EXPECT_EQ(ContinuationBB->size(), 1U);
+ // The continuation only branches on to the if/else merge block.
+ EXPECT_EQ(ContinuationBB->getUniqueSuccessor(), NewIP.getBlock());
+ BasicBlock *CancellationBB = CancelBBTI->getSuccessor(1);
+ EXPECT_EQ(CancellationBB->size(), 1U);
+ Instruction *CancellationTI = CancellationBB->getTerminator();
+ ASSERT_NE(CancellationTI, nullptr);
+ ASSERT_EQ(CancellationTI->getNumSuccessors(), 1U);
+ EXPECT_EQ(CancellationTI->getSuccessor(0), CBB);
+
+ EXPECT_EQ(Cancel->getArgOperand(1), GTID);
+
+ OMPBuilder.popFinalizationCB();
+
+ Builder.CreateUnreachable();
+ EXPECT_FALSE(verifyModule(*M));
+}
+
TEST_F(OpenMPIRBuilderTest, CreateCancelBarrier) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
More information about the llvm-commits
mailing list