[llvm] [OpenMPIRBuilder] Add support for target workshare loops (PR #73360)

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 28 11:14:43 PST 2023


================
@@ -2681,11 +2681,255 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
   return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
 }
 
+// Returns the OpenMP runtime function to call for executing a static
+// worksharing loop of kind `LoopType` whose trip count has integer type `Ty`.
+// Only i32 and i64 are supported by the runtime. The unsigned (`_4u`/`_8u`)
+// entry points are always selected, i.e. integers are always interpreted as
+// unsigned, similarly to CanonicalLoopInfo.
+//
+// An unsupported bit width or an unhandled loop type is a programming error
+// and is reported via llvm_unreachable instead of handing a null
+// FunctionCallee back to the caller.
+static FunctionCallee
+getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
+                            OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  unsigned Bitwidth = Ty->getIntegerBitWidth();
+  Module &M = OMPBuilder->M;
+  switch (LoopType) {
+  case OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
+    break;
+  case OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
+    break;
+  case OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
+    break;
+  }
+  // Every handled loop type returns above for a supported bit width, so
+  // reaching this point means one of the two inputs was invalid.
+  if (Bitwidth != 32 && Bitwidth != 64)
+    llvm_unreachable("unknown OpenMP loop iterator bitwidth");
+  llvm_unreachable("unknown type of OpenMP worksharing loop");
+}
+
+// Inserts a call to the OpenMP Device RTL function which handles loop
+// worksharing for a loop of kind `LoopType`.
+//
+// The runtime call receives, in order: `Ident`, the loop body function
+// `LoopBodyFn` (cast to `ParallelTaskPtr`), `LoopBodyArg`, `TripCount`, and
+// then, depending on `LoopType`:
+//   - DistributeStaticLoop:    block chunk (0)
+//   - ForStaticLoop:           number of threads, thread chunk (1)
+//   - DistributeForStaticLoop: number of threads, block chunk (0),
+//                              thread chunk (1)
+// All constants use the type of `TripCount`, which must be i32 or i64.
+//
+// Every instruction is emitted into `InsertBlock`, right before its last
+// instruction (expected to be the block terminator). The builder's insertion
+// point is left there on return.
+static void createTargetLoopWorkshareCall(
+    OpenMPIRBuilder *OMPBuilder, OpenMPIRBuilder::WorksharingLoopType LoopType,
+    BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
+    Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
+  Type *TripCountTy = TripCount->getType();
+  Module &M = OMPBuilder->M;
+  IRBuilder<> &Builder = OMPBuilder->Builder;
+  FunctionCallee RTLFn =
+      getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
+
+  // The builder's current insertion point is not guaranteed to be inside
+  // InsertBlock, so set it explicitly for every loop type before emitting
+  // anything.
+  Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
+
+  // Builds an integer constant that has the same type as the trip count.
+  auto GetTripCountConst = [&](unsigned C) -> Value * {
+    return TripCountTy->getIntegerBitWidth() == 32
+               ? Builder.getInt32(C)
+               : Builder.getInt64(C);
+  };
+
+  SmallVector<Value *, 8> RealArgs;
+  RealArgs.push_back(Ident);
+  /*loop body func*/
+  RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
+  /*loop body args*/
+  RealArgs.push_back(LoopBodyArg);
+  /*num of iters*/
+  RealArgs.push_back(TripCount);
+  if (LoopType == OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop) {
+    /*block chunk*/ RealArgs.push_back(GetTripCountConst(0));
+    Builder.CreateCall(RTLFn, RealArgs);
+    return;
+  }
+  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
+      M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
+  Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
+
+  /*num of threads*/ RealArgs.push_back(
+      Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
+  if (LoopType ==
+      OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop) {
+    /*block chunk*/ RealArgs.push_back(GetTripCountConst(0));
+  }
+  /*thread chunk */ RealArgs.push_back(GetTripCountConst(1));
+
+  Builder.CreateCall(RTLFn, RealArgs);
+}
+
+static void
+workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
+                            CanonicalLoopInfo *CLI, Value *Ident,
+                            Function &OutlinedFn, Type *ParallelTaskPtr,
+                            const SmallVector<Instruction *, 4> &ToBeDeleted,
+                            OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  IRBuilder<> &Builder = OMPIRBuilder->Builder;
+  BasicBlock *Preheader = CLI->getPreheader();
+  Value *TripCount = CLI->getTripCount();
+
+  // After loop body outlining, the loop body contains only the setup
+  // of the loop body argument structure and the call to the outlined
+  // loop body function. First, we need to move the setup of the loop body
+  // args into the loop preheader.
+  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
+                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));
+
+  // The next step is to remove the whole loop. We do not need it anymore.
+  // That is why we make an unconditional branch from the loop preheader to
+  // the loop exit block.
+  Builder.restoreIP({Preheader, Preheader->end()});
+  Preheader->getTerminator()->eraseFromParent();
+  Builder.CreateBr(CLI->getExit());
+
+  // Delete dead loop blocks
+  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
+  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
+  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
+  CleanUpInfo.EntryBB = CLI->getHeader();
+  CleanUpInfo.ExitBB = CLI->getExit();
+  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
+  DeleteDeadBlocks(BlocksToBeRemoved);
+
+  // Find the instruction which corresponds to loop body argument structure
+  // and remove the call to loop body function instruction.
+  Value *LoopBodyArg;
+  for (auto instIt = Preheader->begin(); instIt != Preheader->end(); ++instIt) {
+    if (CallInst *CallInstruction = dyn_cast<CallInst>(instIt)) {
+      if (CallInstruction->getCalledFunction() == &OutlinedFn) {
+        // Check in case no argument structure has been passed.
+        if (CallInstruction->arg_size() > 1)
+          LoopBodyArg = CallInstruction->getArgOperand(1);
+        else
+          LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
+        CallInstruction->eraseFromParent();
+        break;
+      }
+    }
+  }
----------------
jdoerfert wrote:

The loop is just to look for the call? Go from the OutlinedFn to its single user instead.

https://github.com/llvm/llvm-project/pull/73360


More information about the llvm-commits mailing list