[llvm] [OpenMPIRBuilder] Add support for target workshare loops (PR #73360)

Dominik Adamski via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 30 02:52:24 PST 2023


================
@@ -2228,9 +2228,73 @@ TEST_F(OpenMPIRBuilderTest, UnrollLoopHeuristic) {
   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.unroll.enable"));
 }
 
+TEST_F(OpenMPIRBuilderTest, StaticWorkshareLoopTarget) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  std::string oldDLStr = M->getDataLayoutStr();
+  M->setDataLayout(
+      "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:"
+      "256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:"
+      "256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8");
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.Config.IsTargetDevice = true;
+  OMPBuilder.initialize();
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  InsertPointTy AllocaIP = Builder.saveIP();
+
+  Type *LCTy = Type::getInt32Ty(Ctx);
+  Value *StartVal = ConstantInt::get(LCTy, 10);
+  Value *StopVal = ConstantInt::get(LCTy, 52);
+  Value *StepVal = ConstantInt::get(LCTy, 2);
+  auto LoopBodyGen = [&](InsertPointTy, Value *) {};
+
+  CanonicalLoopInfo *CLI = OMPBuilder.createCanonicalLoop(
+      Loc, LoopBodyGen, StartVal, StopVal, StepVal, false, false);
+  BasicBlock *Preheader = CLI->getPreheader();
+  Value *TripCount = CLI->getTripCount();
+
+  Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+
+  IRBuilder<>::InsertPoint AfterIP = OMPBuilder.applyWorkshareLoop(
+      DL, CLI, AllocaIP, true, OMP_SCHEDULE_Static, nullptr, false, false,
+      false, false, OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop);
+  Builder.restoreIP(AfterIP);
+  Builder.CreateRetVoid();
+
+  OMPBuilder.finalize();
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *WorkshareLoopRuntimeCall = nullptr;
+  for (auto Inst = Preheader->begin(); Inst != Preheader->end(); ++Inst) {
+    CallInst *Call = dyn_cast<CallInst>(Inst);
+    if (Call) {
+      if (Call->getCalledFunction()) {
+        if (Call->getCalledFunction()->getName() ==
+            "__kmpc_for_static_loop_4u") {
----------------
DominikAdamski wrote:

Done. Replaced the nested ifs with early exits.

https://github.com/llvm/llvm-project/pull/73360


More information about the llvm-commits mailing list