[llvm] [OpenMPIRBuilder] Add support for target workshare loops (PR #73360)
Dominik Adamski via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 30 02:52:24 PST 2023
================
@@ -2228,9 +2228,73 @@ TEST_F(OpenMPIRBuilderTest, UnrollLoopHeuristic) {
EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.unroll.enable"));
}
+TEST_F(OpenMPIRBuilderTest, StaticWorkshareLoopTarget) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ std::string oldDLStr = M->getDataLayoutStr();
+ M->setDataLayout(
+ "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:"
+ "256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:"
+ "256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8");
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.Config.IsTargetDevice = true;
+ OMPBuilder.initialize();
+ IRBuilder<> Builder(BB);
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+ InsertPointTy AllocaIP = Builder.saveIP();
+
+ Type *LCTy = Type::getInt32Ty(Ctx);
+ Value *StartVal = ConstantInt::get(LCTy, 10);
+ Value *StopVal = ConstantInt::get(LCTy, 52);
+ Value *StepVal = ConstantInt::get(LCTy, 2);
+ auto LoopBodyGen = [&](InsertPointTy, Value *) {};
+
+ CanonicalLoopInfo *CLI = OMPBuilder.createCanonicalLoop(
+ Loc, LoopBodyGen, StartVal, StopVal, StepVal, false, false);
+ BasicBlock *Preheader = CLI->getPreheader();
+ Value *TripCount = CLI->getTripCount();
+
+ Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+
+ IRBuilder<>::InsertPoint AfterIP = OMPBuilder.applyWorkshareLoop(
+ DL, CLI, AllocaIP, true, OMP_SCHEDULE_Static, nullptr, false, false,
+ false, false, OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop);
+ Builder.restoreIP(AfterIP);
+ Builder.CreateRetVoid();
+
+ OMPBuilder.finalize();
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+
+ CallInst *WorkshareLoopRuntimeCall = nullptr;
+ for (auto Inst = Preheader->begin(); Inst != Preheader->end(); ++Inst) {
+ CallInst *Call = dyn_cast<CallInst>(Inst);
+ if (Call) {
+ if (Call->getCalledFunction()) {
+ if (Call->getCalledFunction()->getName() ==
+ "__kmpc_for_static_loop_4u") {
----------------
DominikAdamski wrote:
Done. Made early exits.
https://github.com/llvm/llvm-project/pull/73360
More information about the llvm-commits mailing list