[llvm] 95eb510 - [OpenMP][IRBuilder] Added if clause to task
via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 22 18:39:52 PDT 2022
Author: Shraiysh Vaishay
Date: 2022-09-23T01:39:41Z
New Revision: 95eb5109afa4c5d26a9432e01eb31dbfc16355c5
URL: https://github.com/llvm/llvm-project/commit/95eb5109afa4c5d26a9432e01eb31dbfc16355c5
DIFF: https://github.com/llvm/llvm-project/commit/95eb5109afa4c5d26a9432e01eb31dbfc16355c5.diff
LOG: [OpenMP][IRBuilder] Added if clause to task
This patch adds support for the `if` clause on the `task` construct in the OpenMP
IRBuilder.
Reviewed By: raghavendhra
Differential Revision: https://reviews.llvm.org/D130615
Added:
Modified:
llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 369392df3fb7c..ba63353edd5ff 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -647,9 +647,16 @@ class OpenMPIRBuilder {
/// \param Tied True if the task is tied, false if the task is untied.
/// \param Final i1 value which is `true` if the task is final, `false` if the
/// task is not final.
+ /// \param IfCondition i1 value. If it evaluates to `false`, an undeferred
+ /// task is generated, and the encountering thread must
+ /// suspend the current task region, for which execution
+ /// cannot be resumed until execution of the structured
+ /// block that is associated with the generated task is
+ /// completed.
InsertPointTy createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
- bool Tied = true, Value *Final = nullptr);
+ bool Tied = true, Value *Final = nullptr,
+ Value *IfCondition = nullptr);
/// Generator for the taskgroup construct
///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 92f6289087e1e..de2ac6e1f9561 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -24,6 +24,7 @@
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
@@ -1289,7 +1290,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
- bool Tied, Value *Final) {
+ bool Tied, Value *Final, Value *IfCondition) {
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -1321,7 +1322,8 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
OI.EntryBB = TaskAllocaBB;
OI.OuterAllocaBB = AllocaIP.getBlock();
OI.ExitBB = TaskExitBB;
- OI.PostOutlineCB = [this, Ident, Tied, Final](Function &OutlinedFn) {
+ OI.PostOutlineCB = [this, Ident, Tied, Final,
+ IfCondition](Function &OutlinedFn) {
// The input IR here looks like the following-
// ```
// func @current_fn() {
@@ -1431,6 +1433,44 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
TaskSize);
}
+ // In the presence of the `if` clause, the following IR is generated:
+ // ...
+ // %data = call @__kmpc_omp_task_alloc(...)
+ // br i1 %if_condition, label %then, label %else
+ // then:
+ // call @__kmpc_omp_task(...)
+ // br label %exit
+ // else:
+ // call @__kmpc_omp_task_begin_if0(...)
+ // call @wrapper_fn(...)
+ // call @__kmpc_omp_task_complete_if0(...)
+ // br label %exit
+ // exit:
+ // ...
+ if (IfCondition) {
+ // `SplitBlockAndInsertIfThenElse` requires the block to have a
+ // terminator.
+ BasicBlock *NewBasicBlock =
+ splitBB(Builder, /*CreateBranch=*/true, "if.end");
+ Instruction *IfTerminator =
+ NewBasicBlock->getSinglePredecessor()->getTerminator();
+ Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
+ Builder.SetInsertPoint(IfTerminator);
+ SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
+ &ElseTI);
+ Builder.SetInsertPoint(ElseTI);
+ Function *TaskBeginFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
+ Function *TaskCompleteFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
+ Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, NewTaskData});
+ if (HasTaskData)
+ Builder.CreateCall(WrapperFunc, {ThreadID, NewTaskData});
+ else
+ Builder.CreateCall(WrapperFunc, {ThreadID});
+ Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
+ Builder.SetInsertPoint(ThenTI);
+ }
// Emit the @__kmpc_omp_task runtime call to spawn the task
Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index aa120c1a08878..92a118be4ff60 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5040,6 +5040,71 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}
+TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+ IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
+ BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+ Builder.SetInsertPoint(BodyBB);
+ Value *IfCondition = Builder.CreateICmp(
+ CmpInst::Predicate::ICMP_EQ, F->getArg(0),
+ ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
+ OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
+ Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
+ /*Tied=*/false, /*Final=*/nullptr,
+ IfCondition));
+ OMPBuilder.finalize();
+ Builder.CreateRetVoid();
+ // createTask with an IfCondition must still produce a verifiable module.
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+ // Recover the task allocation call from the runtime function's users.
+ CallInst *TaskAllocCall = dyn_cast<CallInst>(
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+ ->user_back());
+ ASSERT_NE(TaskAllocCall, nullptr);
+
+ // The task-alloc block must end in a conditional branch on IfCondition.
+ BranchInst *IfConditionBranchInst =
+ dyn_cast<BranchInst>(TaskAllocCall->getParent()->getTerminator());
+ ASSERT_NE(IfConditionBranchInst, nullptr);
+ ASSERT_TRUE(IfConditionBranchInst->isConditional());
+ EXPECT_EQ(IfConditionBranchInst->getCondition(), IfCondition);
+
+ // The deferred spawn (`__kmpc_omp_task`) must be in the then-successor.
+ CallInst *TaskCall = dyn_cast<CallInst>(
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
+ ->user_back());
+ ASSERT_NE(TaskCall, nullptr);
+ EXPECT_EQ(TaskCall->getParent(), IfConditionBranchInst->getSuccessor(0));
+
+ // The `if`-clause runtime calls must be in the else-successor, and the
+ // wrapper function call must sit directly between
+ // `__kmpc_omp_task_begin_if0` and `__kmpc_omp_task_complete_if0`.
+ CallInst *TaskBeginIfCall = dyn_cast<CallInst>(
+ OMPBuilder
+ .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0)
+ ->user_back());
+ CallInst *TaskCompleteCall = dyn_cast<CallInst>(
+ OMPBuilder
+ .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0)
+ ->user_back());
+ ASSERT_NE(TaskBeginIfCall, nullptr);
+ ASSERT_NE(TaskCompleteCall, nullptr);
+ Function *WrapperFunc =
+ dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
+ ASSERT_NE(WrapperFunc, nullptr);
+ CallInst *WrapperFuncCall = dyn_cast<CallInst>(WrapperFunc->user_back());
+ ASSERT_NE(WrapperFuncCall, nullptr);
+ EXPECT_EQ(TaskBeginIfCall->getParent(),
+ IfConditionBranchInst->getSuccessor(1));
+ EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall);
+ EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall);
+}
+
TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
More information about the llvm-commits
mailing list