[llvm] 35fc666 - [OpenMP][IRBuilder] Add support for taskgroup
Shraiysh Vaishay via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 18 22:19:47 PDT 2022
Author: Shraiysh Vaishay
Date: 2022-07-19T10:49:34+05:30
New Revision: 35fc666877e04321129e8d701f0e6f4f28fb8848
URL: https://github.com/llvm/llvm-project/commit/35fc666877e04321129e8d701f0e6f4f28fb8848
DIFF: https://github.com/llvm/llvm-project/commit/35fc666877e04321129e8d701f0e6f4f28fb8848.diff
LOG: [OpenMP][IRBuilder] Add support for taskgroup
This patch adds support for generating the taskgroup construct.
Reviewed By: Meinersbur
Differential Revision: https://reviews.llvm.org/D128203
Added:
Modified:
llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3dfcabffb58a..e4f2fcc649fc 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -630,6 +630,15 @@ class OpenMPIRBuilder {
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
bool Tied = true, Value *Final = nullptr);
+ /// Generator for the taskgroup construct.
+ ///
+ /// Emits a call to __kmpc_taskgroup at \p Loc, generates the region body via
+ /// \p BodyGenCB, and then emits the matching __kmpc_end_taskgroup call in a
+ /// new "taskgroup.exit" block. Both runtime calls receive the same source
+ /// location ident and thread ID.
+ ///
+ /// \param Loc The location where the taskgroup construct was encountered.
+ /// \param AllocaIP The insertion point to be used for alloca instructions;
+ ///                 it is forwarded unchanged to \p BodyGenCB.
+ /// \param BodyGenCB Callback that will generate the region code.
+ ///
+ /// \returns The insertion point right after the __kmpc_end_taskgroup call,
+ ///          or a default-constructed insertion point if the builder could
+ ///          not be updated to \p Loc.
+ InsertPointTy createTaskgroup(const LocationDescription &Loc,
+ InsertPointTy AllocaIP,
+ BodyGenCallbackTy BodyGenCB);
+
/// Functions used to generate reductions. Such functions take two Values
/// representing LHS and RHS of the reduction, respectively, and a reference
/// to the value that is updated to refer to the reduction result.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 574d9174bebf..cee4cddab5e8 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1453,7 +1453,36 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
BodyGenCB(TaskAllocaIP, TaskBodyIP);
- Builder.SetInsertPoint(TaskExitBB);
+ Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
+
+ return Builder.saveIP();
+}
+
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
+                                 InsertPointTy AllocaIP,
+                                 BodyGenCallbackTy BodyGenCB) {
+  if (!updateToLocation(Loc))
+    return InsertPointTy();
+
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+  Value *ThreadID = getOrCreateThreadID(Ident);
+
+  // Emit the @__kmpc_taskgroup runtime call to start the taskgroup.
+  Function *TaskgroupFn =
+      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
+  Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
+
+  // Split off the code that follows the construct into its own block and
+  // generate the region body at the builder's position after the split.
+  BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
+  BodyGenCB(AllocaIP, Builder.saveIP());
+
+  // Insert at the start of the exit block, not at its end: the exit block
+  // holds whatever followed the construct (possibly including a terminator),
+  // and the end call must come before all of it. This mirrors the insertion
+  // point used for TaskExitBB in createTask.
+  Builder.SetInsertPoint(TaskgroupExitBB, TaskgroupExitBB->begin());
+  // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup.
+  Function *EndTaskgroupFn =
+      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
+  Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index d48dd1eaa81c..1026183c3238 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -12,10 +12,12 @@
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "gtest/gtest.h"
@@ -4918,4 +4920,173 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}
+TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
+  AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
+  Value *Val128 =
+      Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "bodygen.load");
+
+  // Filled in by BodyGenCB. Initialized so that a callback that never runs
+  // produces a clean test failure instead of reading indeterminate pointers.
+  Instruction *ThenTerm = nullptr;
+  Instruction *ElseTerm = nullptr;
+  Value *InternalStoreInst = nullptr;
+  Value *InternalLoad32 = nullptr;
+  Value *InternalLoad128 = nullptr;
+  Value *InternalIfCmp = nullptr;
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(AllocaIP);
+    AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
+                                                "bodygen.alloca128");
+
+    Builder.restoreIP(CodeGenIP);
+    // Loading and storing captured pointer and values
+    InternalStoreInst = Builder.CreateStore(Val128, Local128);
+    InternalLoad32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
+                                        "bodygen.load32");
+
+    InternalLoad128 = Builder.CreateLoad(Local128->getAllocatedType(), Local128,
+                                         "bodygen.local.load128");
+    InternalIfCmp = Builder.CreateICmpNE(
+        InternalLoad32,
+        Builder.CreateTrunc(InternalLoad128, InternalLoad32->getType()));
+    SplitBlockAndInsertIfThenElse(InternalIfCmp,
+                                  CodeGenIP.getBlock()->getTerminator(),
+                                  &ThenTerm, &ElseTerm);
+  };
+
+  BasicBlock *AllocaBB = Builder.GetInsertBlock();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  Builder.restoreIP(OMPBuilder.createTaskgroup(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
+      BodyGenCB));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  ASSERT_NE(ThenTerm, nullptr);
+  ASSERT_NE(ElseTerm, nullptr);
+
+  CallInst *TaskgroupCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup)
+          ->user_back());
+  ASSERT_NE(TaskgroupCall, nullptr);
+  CallInst *EndTaskgroupCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup)
+          ->user_back());
+  ASSERT_NE(EndTaskgroupCall, nullptr);
+
+  // Verify the Ident argument. dyn_cast (not cast) so that the null checks
+  // below can actually fail instead of cast asserting first.
+  GlobalVariable *Ident =
+      dyn_cast<GlobalVariable>(TaskgroupCall->getArgOperand(0));
+  ASSERT_NE(Ident, nullptr);
+  ASSERT_TRUE(Ident->hasInitializer());
+  Constant *Initializer = Ident->getInitializer();
+  GlobalVariable *SrcStrGlob = dyn_cast<GlobalVariable>(
+      Initializer->getOperand(4)->stripPointerCasts());
+  ASSERT_NE(SrcStrGlob, nullptr);
+  ConstantDataArray *SrcSrc =
+      dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
+  ASSERT_NE(SrcSrc, nullptr);
+
+  // Verify the global thread ID argument.
+  CallInst *GTID = dyn_cast<CallInst>(TaskgroupCall->getArgOperand(1));
+  ASSERT_NE(GTID, nullptr);
+  EXPECT_EQ(GTID->arg_size(), 1U);
+  EXPECT_EQ(GTID->getCalledFunction(), OMPBuilder.getOrCreateRuntimeFunctionPtr(
+                                           OMPRTL___kmpc_global_thread_num));
+
+  // The begin and end calls must use the same location ident and thread ID.
+  EXPECT_EQ(EndTaskgroupCall->getArgOperand(0),
+            TaskgroupCall->getArgOperand(0));
+  EXPECT_EQ(EndTaskgroupCall->getArgOperand(1),
+            TaskgroupCall->getArgOperand(1));
+
+  // Checking the general structure of the IR generated is same as expected.
+  Instruction *GeneratedStoreInst = TaskgroupCall->getNextNonDebugInstruction();
+  EXPECT_EQ(GeneratedStoreInst, InternalStoreInst);
+  Instruction *GeneratedLoad32 =
+      GeneratedStoreInst->getNextNonDebugInstruction();
+  EXPECT_EQ(GeneratedLoad32, InternalLoad32);
+  Instruction *GeneratedLoad128 = GeneratedLoad32->getNextNonDebugInstruction();
+  EXPECT_EQ(GeneratedLoad128, InternalLoad128);
+
+  // Check the CFG ordering introduced by the if/else in the body, and that the
+  // `__kmpc_end_taskgroup` call comes after the if branching. The result must
+  // be checked; discarding it would make this verification a no-op.
+  BasicBlock *RefOrder[] = {TaskgroupCall->getParent(), ThenTerm->getParent(),
+                            ThenTerm->getSuccessor(0),
+                            EndTaskgroupCall->getParent(),
+                            ElseTerm->getParent()};
+  EXPECT_TRUE(verifyDFSOrder(F, RefOrder));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTaskgroupWithTasks) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(AllocaIP);
+    AllocaInst *Alloca32 =
+        Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, "bodygen.alloca32");
+    AllocaInst *Alloca64 =
+        Builder.CreateAlloca(Builder.getInt64Ty(), nullptr, "bodygen.alloca64");
+    Builder.restoreIP(CodeGenIP);
+    auto TaskBodyGenCB1 = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+      Builder.restoreIP(CodeGenIP);
+      LoadInst *LoadValue =
+          Builder.CreateLoad(Alloca64->getAllocatedType(), Alloca64);
+      Value *AddInst = Builder.CreateAdd(LoadValue, Builder.getInt64(64));
+      Builder.CreateStore(AddInst, Alloca64);
+    };
+    OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
+    Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, TaskBodyGenCB1));
+
+    auto TaskBodyGenCB2 = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+      Builder.restoreIP(CodeGenIP);
+      LoadInst *LoadValue =
+          Builder.CreateLoad(Alloca32->getAllocatedType(), Alloca32);
+      Value *AddInst = Builder.CreateAdd(LoadValue, Builder.getInt32(32));
+      Builder.CreateStore(AddInst, Alloca32);
+    };
+    OpenMPIRBuilder::LocationDescription Loc2(Builder.saveIP(), DL);
+    Builder.restoreIP(OMPBuilder.createTask(Loc2, AllocaIP, TaskBodyGenCB2));
+  };
+
+  BasicBlock *AllocaBB = Builder.GetInsertBlock();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  Builder.restoreIP(OMPBuilder.createTaskgroup(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
+      BodyGenCB));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *TaskgroupCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup)
+          ->user_back());
+  ASSERT_NE(TaskgroupCall, nullptr);
+  CallInst *EndTaskgroupCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup)
+          ->user_back());
+  ASSERT_NE(EndTaskgroupCall, nullptr);
+
+  Function *TaskAllocFn =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
+  ASSERT_EQ(TaskAllocFn->getNumUses(), 2U);
+
+  // Take the two distinct users of __kmpc_omp_task_alloc. Post-incrementing a
+  // temporary begin() iterator would dereference the first user twice. The
+  // order of a use list is not something this test should depend on, so the
+  // calls are not assumed to correspond to the tasks in creation order.
+  auto TaskAllocUserIt = TaskAllocFn->user_begin();
+  CallInst *FirstTaskAllocCall = dyn_cast_or_null<CallInst>(*TaskAllocUserIt);
+  ++TaskAllocUserIt;
+  CallInst *SecondTaskAllocCall = dyn_cast_or_null<CallInst>(*TaskAllocUserIt);
+  ASSERT_NE(FirstTaskAllocCall, nullptr);
+  ASSERT_NE(SecondTaskAllocCall, nullptr);
+  ASSERT_NE(FirstTaskAllocCall, SecondTaskAllocCall);
+
+  // Verify that both tasks are generated one after the other inside the
+  // taskgroup construct, i.e. between the __kmpc_taskgroup and
+  // __kmpc_end_taskgroup calls in the control flow.
+  BasicBlock *TaskgroupBB = TaskgroupCall->getParent();
+  BasicBlock *EndTaskgroupBB = EndTaskgroupCall->getParent();
+  BasicBlock *FirstTaskBB = FirstTaskAllocCall->getParent();
+  BasicBlock *SecondTaskBB = SecondTaskAllocCall->getParent();
+  BasicBlock *RefOrder[] = {TaskgroupBB, FirstTaskBB, SecondTaskBB,
+                            EndTaskgroupBB};
+  BasicBlock *SwappedOrder[] = {TaskgroupBB, SecondTaskBB, FirstTaskBB,
+                                EndTaskgroupBB};
+  EXPECT_TRUE(verifyDFSOrder(F, RefOrder) || verifyDFSOrder(F, SwappedOrder));
+}
+
} // namespace
More information about the llvm-commits
mailing list