[Mlir-commits] [llvm] [mlir] [clang] [OpenMP] Migrate GPU Reductions CodeGen from Clang to OMPIRBuilder (PR #80343)

Jan Leyonberg llvmlistbot at llvm.org
Tue Feb 6 10:31:46 PST 2024


================
@@ -2042,35 +2057,1378 @@ OpenMPIRBuilder::createSection(const LocationDescription &Loc,
                               /*IsCancellable*/ true);
 }
 
-/// Create a function with a unique name and a "void (i8*, i8*)" signature in
-/// the given module and return it.
-Function *getFreshReductionFunc(Module &M) {
-  Type *VoidTy = Type::getVoidTy(M.getContext());
-  Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
-  auto *FuncTy =
-      FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
-  return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
-                          M.getDataLayout().getDefaultGlobalsAddressSpace(),
-                          ".omp.reduction.func", &M);
+/// Return the insertion point located directly behind instruction \p I, in
+/// the basic block that contains \p I.
+static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
+  BasicBlock::iterator AfterI = I->getIterator();
+  ++AfterI;
+  return OpenMPIRBuilder::InsertPointTy(I->getParent(), AfterI);
 }
 
-OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
-    const LocationDescription &Loc, InsertPointTy AllocaIP,
-    ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
-  for (const ReductionInfo &RI : ReductionInfos) {
+/// Emit an appending-linkage global array named \p Name whose entries are the
+/// values tracked in \p List, each cast to the builder's default pointer type,
+/// and place that global in the "llvm.metadata" section.
+///
+/// Nothing is emitted when \p List is empty.
+void OpenMPIRBuilder::emitUsed(StringRef Name,
+                               std::vector<WeakTrackingVH> &List) {
+  if (List.empty())
+    return;
+
+  // ConstantArray needs plain constants, so unwrap and cast every handle.
+  Type *PtrTy = Builder.getPtrTy();
+  SmallVector<Constant *, 8> Entries;
+  Entries.reserve(List.size());
+  for (WeakTrackingVH &Handle : List)
+    Entries.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*Handle), PtrTy));
+
+  ArrayType *EntriesTy = ArrayType::get(PtrTy, Entries.size());
+  auto *UsedGV = new GlobalVariable(M, EntriesTy, /*isConstant=*/false,
+                                    GlobalValue::AppendingLinkage,
+                                    ConstantArray::get(EntriesTy, Entries),
+                                    Name);
+  UsedGV->setSection("llvm.metadata");
+}
+
+/// Emit a call to the runtime entry __kmpc_get_hardware_thread_id_in_block at
+/// the current insertion point and return the call's result.
+Value *OpenMPIRBuilder::getGPUThreadID() {
+  FunctionCallee ThreadIdFn = getOrCreateRuntimeFunction(
+      M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
+  return Builder.CreateCall(ThreadIdFn, {});
+}
+
+/// Emit a call to the runtime entry __kmpc_get_warp_size at the current
+/// insertion point and return the call's result.
+Value *OpenMPIRBuilder::getGPUWarpSize() {
+  FunctionCallee WarpSizeFn =
+      getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size);
+  return Builder.CreateCall(WarpSizeFn, {});
+}
+
+/// Compute the warp id of the current thread: the hardware thread id shifted
+/// right (arithmetically) by log2 of the configured warp size.
+Value *OpenMPIRBuilder::getNVPTXWarpID() {
+  const unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
+  const unsigned ShiftAmount = Log2_32(WarpSize);
+  Value *ThreadID = getGPUThreadID();
+  return Builder.CreateAShr(ThreadID, ShiftAmount, "nvptx_warp_id");
+}
+
+/// Compute the lane id of the current thread: the hardware thread id masked
+/// with (warp size - 1), where the warp size comes from the grid values in
+/// Config and is rounded down to a power of two by Log2_32.
+///
+/// \returns An i32 value named "nvptx_lane_id".
+Value *OpenMPIRBuilder::getNVPTXLaneID() {
+  unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
+  assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
+  // Build the mask as (1 << bits) - 1. This is well defined for every value
+  // the assert admits (0..31); the previous form `~0u >> (32u - bits)` shifted
+  // by the full width of the type when bits == 0, which is undefined.
+  unsigned LaneIDMask = (1u << LaneIDBits) - 1u;
+  return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
+                           "nvptx_lane_id");
+}
+
+/// Produce a value of type \p ToType from \p From.
+///
+/// The conversion strategy is chosen in this order:
+///  - identical types: \p From is returned unchanged;
+///  - equal store sizes: a bitcast;
+///  - both integer types: a signed integer cast;
+///  - otherwise: \p From is stored into a fresh stack slot of type \p ToType
+///    (allocated at \p AllocaIP) and the slot is loaded back as \p ToType.
+Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
+                                        Type *ToType) {
+  const DataLayout &DL = M.getDataLayout();
+  Type *SrcTy = From->getType();
+  const uint64_t SrcBytes = DL.getTypeStoreSize(SrcTy);
+  const uint64_t DstBytes = DL.getTypeStoreSize(ToType);
+  assert(SrcBytes > 0 && "From size must be greater than zero");
+  assert(DstBytes > 0 && "To size must be greater than zero");
+
+  if (SrcTy == ToType)
+    return From;
+  if (SrcBytes == DstBytes)
+    return Builder.CreateBitCast(From, ToType);
+  if (SrcTy->isIntegerTy() && ToType->isIntegerTy())
+    return Builder.CreateIntCast(From, ToType, /*isSigned*/ true);
+
+  // Fall back to a round trip through memory. The slot is created at the
+  // alloca insertion point, then code generation resumes where it was.
+  InsertPointTy BodyIP = Builder.saveIP();
+  Builder.restoreIP(AllocaIP);
+  Value *Slot = Builder.CreateAlloca(ToType);
+  Builder.restoreIP(BodyIP);
+
+  Value *SlotAsSrcPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      Slot, SrcTy->getPointerTo());
+  Builder.CreateStore(From, SlotAsSrcPtr);
+  return Builder.CreateLoad(ToType, Slot);
+}
+
+/// Emit a call to the runtime warp-shuffle routine for \p Element and return
+/// the shuffled value converted back to \p ElementType.
+///
+/// \p Element is first widened/cast to i32 (store size <= 4 bytes) or i64
+/// (otherwise) because the runtime only offers __kmpc_shuffle_int32 and
+/// __kmpc_shuffle_int64. The call receives the cast element, \p Offset and
+/// the warp size as an i16.
+///
+/// \param AllocaIP Insertion point used by castValueToType when it needs a
+///                 temporary stack slot.
+/// \param Element The value to shuffle.
+/// \param ElementType Type whose store size selects the 32- or 64-bit
+///                    routine; must be at most 8 bytes.
+/// \param Offset Passed through to the runtime shuffle call.
+/// \returns The shuffled value, of type \p ElementType.
+Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
+                                                     Value *Element,
+                                                     Type *ElementType,
+                                                     Value *Offset) {
+  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
+  assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
+
+  // Cast all types to 32- or 64-bit values before calling shuffle routines.
+  Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
+  Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
+  // The warp size is cast to i16 exactly once; a second cast to the same type
+  // would be a no-op.
+  Value *WarpSize =
+      Builder.CreateIntCast(getGPUWarpSize(), Builder.getInt16Ty(), true);
+  Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
+      Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
+                : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
+  Value *ShuffleCall =
+      Builder.CreateCall(ShuffleFunc, {ElemCast, Offset, WarpSize});
+  // Convert the i32/i64 result back to the caller's element type. Casting to
+  // CastTy here (as before) was a no-op and handed 1- and 2-byte elements
+  // back as i32, so a subsequent store wrote past the element.
+  return castValueToType(AllocaIP, ShuffleCall, ElementType);
+}
+
+/// Emit IR that transfers one reduction element of type \p ElemType from
+/// \p SrcAddr to \p DstAddr, sending every piece through
+/// createRuntimeShuffleFunction.
+///
+/// The element is handled in chunks of 8, 4, 2 and 1 bytes, widest first.
+/// For a given chunk width, if more than one chunk fits into the bytes that
+/// are still left, a loop (.shuffle.pre_cond / .shuffle.then / .shuffle.exit)
+/// is emitted; otherwise a single load / shuffle / store sequence is emitted.
+/// After each width the remaining size becomes Size % IntSize.
+///
+/// \param AllocaIP Forwarded to createRuntimeShuffleFunction.
+/// \param SrcAddr Address the element is loaded from.
+/// \param DstAddr Address the shuffled element is stored to.
+/// \param ElemType Element type; its store size drives the chunking.
+/// \param Offset Forwarded to createRuntimeShuffleFunction.
+/// \param ReductionArrayTy Not referenced in this function body.
+void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
+                                      Value *DstAddr, Type *ElemType,
+                                      Value *Offset, Type *ReductionArrayTy) {
+  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElemType);
+  // Create the loop over the big sized data.
+  // ptr = (void*)Elem;
+  // ptrEnd = (void*) Elem + 1;
+  // Step = 8;
+  // while (ptr + Step < ptrEnd)
+  //   shuffle((int64_t)*ptr);
+  // Step = 4;
+  // while (ptr + Step < ptrEnd)
+  //   shuffle((int32_t)*ptr);
+  // ...
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  // ElemPtr / Ptr are the running destination / source cursors; they are
+  // advanced as chunks are emitted and carried over to the next chunk width.
+  Value *ElemPtr = DstAddr;
+  Value *Ptr = SrcAddr;
+  for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
+    // Skip chunk widths larger than what is left to copy.
+    if (Size < IntSize)
+      continue;
+    Type *IntType = Builder.getIntNTy(IntSize * 8);
+    Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        Ptr, IntType->getPointerTo(), Ptr->getName() + ".ascast");
+    // One-past-the-end address of the source element (SrcAddr + 1 element).
+    Value *SrcAddrGEP =
+        Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
+    ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        ElemPtr, IntType->getPointerTo(), ElemPtr->getName() + ".ascast");
+
+    Function *CurFunc = Builder.GetInsertBlock()->getParent();
+    if ((Size / IntSize) > 1) {
+      // More than one chunk of this width: emit a loop.
+      Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          SrcAddrGEP, Builder.getPtrTy());
+      BasicBlock *PreCondBB =
+          BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
+      BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
+      BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
+      BasicBlock *CurrentBB = Builder.GetInsertBlock();
+      emitBlock(PreCondBB, CurFunc);
+      // PHIs merge the cursors coming from before the loop with the ones
+      // advanced at the end of the loop body.
+      PHINode *PhiSrc =
+          Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
+      PhiSrc->addIncoming(Ptr, CurrentBB);
+      PHINode *PhiDest =
+          Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
+      PhiDest->addIncoming(ElemPtr, CurrentBB);
+      Ptr = PhiSrc;
+      ElemPtr = PhiDest;
+      // Continue while at least IntSize bytes remain before PtrEnd.
+      Value *PtrDiff = Builder.CreatePtrDiff(
+          Builder.getInt8Ty(), PtrEnd,
+          Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr, Builder.getPtrTy()));
+      Builder.CreateCondBr(
+          Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
+          ExitBB);
+      emitBlock(ThenBB, CurFunc);
+      Value *Res = createRuntimeShuffleFunction(
+          AllocaIP,
+          Builder.CreateAlignedLoad(
+              IntType, Ptr, M.getDataLayout().getPrefTypeAlign(ElemType)),
+          IntType, Offset);
+      Builder.CreateAlignedStore(Res, ElemPtr,
+                                 M.getDataLayout().getPrefTypeAlign(ElemType));
+      // Advance both cursors by one chunk and feed them back into the PHIs.
+      Value *LocalPtr =
+          Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
+      Value *LocalElemPtr =
+          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
+      PhiSrc->addIncoming(LocalPtr, ThenBB);
+      PhiDest->addIncoming(LocalElemPtr, ThenBB);
+      emitBranch(PreCondBB);
+      emitBlock(ExitBB, CurFunc);
+    } else {
+      // Exactly one chunk of this width: straight-line code.
+      Value *Res = createRuntimeShuffleFunction(
+          AllocaIP, Builder.CreateLoad(IntType, Ptr), IntType, Offset);
+      // Narrow the shuffled value when the element is an integer type that is
+      // narrower than what the shuffle helper returned.
+      if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
+                                         Res->getType()->getScalarSizeInBits())
+        Res = Builder.CreateTrunc(Res, ElemType);
+      Builder.CreateStore(Res, ElemPtr);
+      Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
+      ElemPtr =
+          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
+    }
+    // Bytes still to be copied with narrower chunks.
+    Size = Size % IntSize;
+  }
+}
+
+/// Emit IR that copies every element referenced by the reduce list at
+/// \p SrcBase into the reduce list at \p DestBase.
+///
+/// Both lists are arrays of pointers of type \p ReductionArrayTy, with one
+/// entry per item in \p ReductionInfos.
+///
+/// \param AllocaIP Where temporaries are allocated (RemoteLaneToThread) and
+///                 what is forwarded to shuffleAndStore.
+/// \param Action   CopyAction::RemoteLaneToThread allocates a fresh element
+///                 per entry, fills it via shuffleAndStore and stores its
+///                 address into the destination list.
+///                 CopyAction::ThreadCopy copies the value into the element
+///                 the destination list already points to.
+/// \param ReductionArrayTy Array-of-pointers type of both reduce lists.
+/// \param ReductionInfos   Per-element type and evaluation-kind information.
+/// \param SrcBase  Base address of the source reduce list.
+/// \param DestBase Base address of the destination reduce list.
+/// \param CopyOptions Only RemoteLaneOffset is read here; it is forwarded to
+///                    shuffleAndStore.
+void OpenMPIRBuilder::emitReductionListCopy(
+    InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
+    ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
+    CopyOptionsTy CopyOptions) {
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;
+
+  // Iterates, element-by-element, through the source Reduce list and
+  // make a copy.
+  for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    Value *SrcElementAddr = nullptr;
+    Value *DestElementAddr = nullptr;
+    Value *DestElementPtrAddr = nullptr;
+    // Should we shuffle in an element from a remote lane?
+    bool ShuffleInElement = false;
+    // Set to true to update the pointer in the dest Reduce list to a
+    // newly created element.
+    bool UpdateDestListPtr = false;
+
+    // Step 1.1: Get the address for the src element in the Reduce list.
+    Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
+        ReductionArrayTy, SrcBase,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    SrcElementAddr = Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrAddr);
+
+    // Step 1.2: Create a temporary to store the element in the destination
+    // Reduce list.
+    DestElementPtrAddr = Builder.CreateInBoundsGEP(
+        ReductionArrayTy, DestBase,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    switch (Action) {
+    case CopyAction::RemoteLaneToThread: {
+      // Allocate the temporary at AllocaIP, then resume at the current spot.
+      InsertPointTy CurIP = Builder.saveIP();
+      Builder.restoreIP(AllocaIP);
+      AllocaInst *DestAlloca = Builder.CreateAlloca(RI.ElementType, nullptr,
+                                                    ".omp.reduction.element");
+      DestAlloca->setAlignment(
+          M.getDataLayout().getPrefTypeAlign(RI.ElementType));
+      DestElementAddr = DestAlloca;
+      DestElementAddr =
+          Builder.CreateAddrSpaceCast(DestElementAddr, Builder.getPtrTy(),
+                                      DestElementAddr->getName() + ".ascast");
+      Builder.restoreIP(CurIP);
+      ShuffleInElement = true;
+      UpdateDestListPtr = true;
+      break;
+    }
+    case CopyAction::ThreadCopy: {
+      // The destination list already points at storage for the element.
+      DestElementAddr =
+          Builder.CreateLoad(Builder.getPtrTy(), DestElementPtrAddr);
+      break;
+    }
+    }
+
+    // Now that all active lanes have read the element in the
+    // Reduce list, shuffle over the value from the remote lane.
+    if (ShuffleInElement) {
+      shuffleAndStore(AllocaIP, SrcElementAddr, DestElementAddr, RI.ElementType,
+                      RemoteLaneOffset, ReductionArrayTy);
+    } else {
+      switch (RI.EvaluationKind) {
+      case EvaluationKindTy::Scalar: {
+        Value *Elem = Builder.CreateLoad(RI.ElementType, SrcElementAddr);
+        // Store the source element value to the dest element address.
+        Builder.CreateStore(Elem, DestElementAddr);
+        break;
+      }
+      case EvaluationKindTy::Complex: {
+        // Previously this case emitted nothing, so complex elements were
+        // silently left uncopied. Copy the real and imaginary parts
+        // individually.
+        // NOTE(review): assumes a complex element is represented as a
+        // two-field struct {real, imag} -- confirm against the producers of
+        // ReductionInfo.
+        assert(RI.ElementType->isStructTy() &&
+               RI.ElementType->getStructNumElements() == 2 &&
+               "Expected complex element to be a {real, imag} struct");
+        Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
+            RI.ElementType, SrcElementAddr, 0, 0, ".realp");
+        Value *SrcReal = Builder.CreateLoad(
+            RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
+        Value *SrcImagPtr = Builder.CreateConstInBoundsGEP2_32(
+            RI.ElementType, SrcElementAddr, 0, 1, ".imagp");
+        Value *SrcImag = Builder.CreateLoad(
+            RI.ElementType->getStructElementType(1), SrcImagPtr, ".imag");
+
+        Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
+            RI.ElementType, DestElementAddr, 0, 0, ".realp");
+        Value *DestImagPtr = Builder.CreateConstInBoundsGEP2_32(
+            RI.ElementType, DestElementAddr, 0, 1, ".imagp");
+        Builder.CreateStore(SrcReal, DestRealPtr);
+        Builder.CreateStore(SrcImag, DestImagPtr);
+        break;
+      }
+      case EvaluationKindTy::Aggregate: {
+        Value *SizeVal = Builder.getInt64(
+            M.getDataLayout().getTypeStoreSize(RI.ElementType));
+        Builder.CreateMemCpy(
+            DestElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
+            SrcElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
+            SizeVal, false);
+        break;
+      }
+      }
+    }
+
+    // Step 3.1: Modify reference in dest Reduce list as needed.
+    // Modifying the reference in Reduce list to point to the newly
+    // created element.  The element is live in the current function
+    // scope and that of functions it invokes (i.e., reduce_function).
+    // RemoteReduceData[i] = (void*)&RemoteElem
+    if (UpdateDestListPtr) {
+      Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          DestElementAddr, Builder.getPtrTy(),
+          DestElementAddr->getName() + ".ascast");
+      Builder.CreateStore(CastDestAddr, DestElementPtrAddr);
+    }
+  }
+}
+
+/// Create and return the internal helper function
+/// "_omp_reduction_inter_warp_copy_func" with the signature
+/// void(ptr ReduceList, i32 NumWarps).
+///
+/// For every entry of \p ReductionInfos the generated body moves the element
+/// in chunks of 4, 2 and 1 bytes. Per chunk it emits: a barrier; a store of
+/// the chunk into the transfer-medium global at index warp_id by threads whose
+/// lane id is zero; a second barrier; and a load from the transfer medium at
+/// index thread_id followed by a store into the reduce-list element by threads
+/// whose hardware thread id is below NumWarps.
+///
+/// The builder's insertion point is saved on entry and restored before
+/// returning.
+///
+/// \param Loc Source location used for the ident / thread-id lookup and for
+///            the debug location of the emitted barriers.
+/// \param ReductionInfos Element types of the reduce list entries.
+/// \param FuncAttrs Attributes applied to the created function.
+Function *OpenMPIRBuilder::emitInterWarpCopyFunction(
+    const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
+    AttributeList FuncAttrs) {
+  InsertPointTy SavedIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy = FunctionType::get(
+      Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()},
+      /* IsVarArg */ false);
+  Function *WcFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_inter_warp_copy_func", &M);
+  WcFunc->setAttributes(FuncAttrs);
+  WcFunc->addParamAttr(0, Attribute::NoUndef);
+  WcFunc->addParamAttr(1, Attribute::NoUndef);
+  BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc);
+  Builder.SetInsertPoint(EntryBB);
+
+  // ReduceList: thread local Reduce list.
+  // At the stage of the computation when this function is called, partially
+  // aggregated values reside in the first lane of every active warp.
+  Argument *ReduceListArg = WcFunc->getArg(0);
+  // NumWarps: number of warps active in the parallel region.  This could
+  // be smaller than 32 (max warps in a CTA) for partial block reduction.
+  Argument *NumWarpsArg = WcFunc->getArg(1);
+
+  // This array is used as a medium to transfer, one reduce element at a time,
+  // the data from the first lane of every warp to lanes in the first warp
+  // in order to perform the final step of a reduction in a parallel region
+  // (reduction across warps).  The array is placed in NVPTX __shared__ memory
+  // for reduced latency, as well as to have a distinct copy for concurrently
+  // executing target regions.  The array is declared with common linkage so
+  // as to be shared across compilation units.
+  // NOTE(review): the code below creates the global with WeakAnyLinkage in
+  // address space 3, not common linkage -- confirm which one is intended.
+  StringRef TransferMediumName =
+      "__openmp_nvptx_data_transfer_temporary_storage";
+  GlobalVariable *TransferMedium = M.getGlobalVariable(TransferMediumName);
+  unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
+  ArrayType *ArrayTy = ArrayType::get(Builder.getInt32Ty(), WarpSize);
+  // Reuse the global if the module already has one with this name.
+  if (!TransferMedium) {
+    TransferMedium = new GlobalVariable(
+        M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
+        UndefValue::get(ArrayTy), TransferMediumName,
+        /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
+        /*AddressSpace=*/3);
+  }
+
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+
+  // Get the CUDA thread id of the current OpenMP thread on the GPU.
+  Value *GPUThreadID = getGPUThreadID();
+  // nvptx_lane_id = nvptx_id % warpsize
+  Value *LaneID = getNVPTXLaneID();
+  // nvptx_warp_id = nvptx_id / warpsize
+  Value *WarpID = getNVPTXWarpID();
+
+  // Allocas go to the first insertion point of the entry block, i.e. in
+  // front of the id computations emitted above.
+  InsertPointTy AllocaIP =
+      InsertPointTy(Builder.GetInsertBlock(),
+                    Builder.GetInsertBlock()->getFirstInsertionPt());
+  Type *Arg0Type = ReduceListArg->getType();
+  Type *Arg1Type = NumWarpsArg->getType();
+  Builder.restoreIP(AllocaIP);
+  AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
+      Arg0Type, nullptr, ReduceListArg->getName() + ".addr");
+  AllocaInst *NumWarpsAlloca =
+      Builder.CreateAlloca(Arg1Type, nullptr, NumWarpsArg->getName() + ".addr");
+  Value *ThreadID =
+      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize));
+  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListAlloca, Arg0Type, ReduceListAlloca->getName() + ".ascast");
+  Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      NumWarpsAlloca, Arg1Type->getPointerTo(),
+      NumWarpsAlloca->getName() + ".ascast");
+  // Spill both arguments to their stack slots.
+  Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
+  Builder.CreateStore(NumWarpsArg, NumWarpsAddrCast);
+  // Later allocas are placed right after the NumWarps slot; regular code
+  // generation continues at the end of the entry block.
+  AllocaIP = getInsertPointAfterInstr(NumWarpsAlloca);
+  InsertPointTy CodeGenIP =
+      getInsertPointAfterInstr(&Builder.GetInsertBlock()->back());
+  Builder.restoreIP(CodeGenIP);
+
+  Value *ReduceList =
+      Builder.CreateLoad(Builder.getPtrTy(), ReduceListAddrCast);
+
+  for (auto En : enumerate(ReductionInfos)) {
+    //
+    // Warp master copies reduce element to transfer medium in __shared__
+    // memory.
+    //
+    const ReductionInfo &RI = En.value();
+    unsigned RealTySize = M.getDataLayout().getTypeAllocSize(RI.ElementType);
+    // Copy the element in chunks of 4, then 2, then 1 byte(s); RealTySize
+    // holds the bytes that are still left for narrower chunks.
+    for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
+      Type *CType = Builder.getIntNTy(TySize * 8);
+
+      unsigned NumIters = RealTySize / TySize;
+      if (NumIters == 0)
+        continue;
+      Value *Cnt = nullptr;
+      Value *CntAddr = nullptr;
+      BasicBlock *PrecondBB = nullptr;
+      BasicBlock *ExitBB = nullptr;
+      // More than one chunk of this width: wrap the copy in a counted loop
+      // driven by a stack counter (.cnt.addr).
+      if (NumIters > 1) {
+        CodeGenIP = Builder.saveIP();
+        Builder.restoreIP(AllocaIP);
+        CntAddr =
+            Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, ".cnt.addr");
+
+        CntAddr = Builder.CreateAddrSpaceCast(CntAddr, Builder.getPtrTy(),
+                                              CntAddr->getName() + ".ascast");
+        Builder.restoreIP(CodeGenIP);
+        Builder.CreateStore(Constant::getNullValue(Builder.getInt32Ty()),
+                            CntAddr,
+                            /*Volatile=*/false);
+        PrecondBB = BasicBlock::Create(Ctx, "precond");
+        ExitBB = BasicBlock::Create(Ctx, "exit");
+        BasicBlock *BodyBB = BasicBlock::Create(Ctx, "body");
+        emitBlock(PrecondBB, Builder.GetInsertBlock()->getParent());
+        Cnt = Builder.CreateLoad(Builder.getInt32Ty(), CntAddr,
+                                 /*Volatile=*/false);
+        Value *Cmp = Builder.CreateICmpULT(
+            Cnt, ConstantInt::get(Builder.getInt32Ty(), NumIters));
+        Builder.CreateCondBr(Cmp, BodyBB, ExitBB);
+        emitBlock(BodyBB, Builder.GetInsertBlock()->getParent());
+      }
+
+      // kmpc_barrier.
+      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
+                    omp::Directive::OMPD_unknown,
+                    /* ForceSimpleCall */ false,
+                    /* CheckCancelFlag */ true, ThreadID);
+      BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
+      BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
+      BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+      // if (lane_id  == 0)
+      Value *IsWarpMaster = Builder.CreateIsNull(LaneID, "warp_master");
+      Builder.CreateCondBr(IsWarpMaster, ThenBB, ElseBB);
+      emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
+
+      // Reduce element = LocalReduceList[i]
+      auto *RedListArrayTy =
+          ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+      Type *IndexTy = Builder.getIndexTy(
+          M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+      Value *ElemPtrPtr =
+          Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
+                                    {ConstantInt::get(IndexTy, 0),
+                                     ConstantInt::get(IndexTy, En.index())});
+      // elemptr = ((CopyType*)(elemptrptr)) + I
+      // NOTE(review): ElemPtr is only advanced by Cnt when NumIters > 1; no
+      // offset for bytes already handled by a wider TySize pass is visible
+      // here -- confirm this is intended for element sizes that are not a
+      // multiple of 4.
+      Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+      if (NumIters > 1)
+        ElemPtr = Builder.CreateGEP(Builder.getInt32Ty(), ElemPtr, Cnt);
+
+      // Get pointer to location in transfer medium.
+      // MediumPtr = &medium[warp_id]
+      Value *MediumPtr = Builder.CreateInBoundsGEP(
+          ArrayTy, TransferMedium, {Builder.getInt64(0), WarpID});
+      // elem = *elemptr
+      //*MediumPtr = elem
+      Value *Elem = Builder.CreateLoad(CType, ElemPtr);
+      // Store the source element value to the dest element address.
+      Builder.CreateStore(Elem, MediumPtr,
+                          /*IsVolatile*/ true);
+      Builder.CreateBr(MergeBB);
+
+      // else
+      emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
+      Builder.CreateBr(MergeBB);
+
+      // endif
+      emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
+      // Second barrier, emitted between the stores into the transfer medium
+      // above and the loads from it below.
+      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
+                    omp::Directive::OMPD_unknown,
+                    /* ForceSimpleCall */ false,
+                    /* CheckCancelFlag */ true, ThreadID);
+
+      // Warp 0 copies reduce element from transfer medium
+      BasicBlock *W0ThenBB = BasicBlock::Create(Ctx, "then");
+      BasicBlock *W0ElseBB = BasicBlock::Create(Ctx, "else");
+      BasicBlock *W0MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+      Value *NumWarpsVal =
+          Builder.CreateLoad(Builder.getInt32Ty(), NumWarpsAddrCast);
+      // Up to 32 threads in warp 0 are active.
+      Value *IsActiveThread =
+          Builder.CreateICmpULT(GPUThreadID, NumWarpsVal, "is_active_thread");
+      Builder.CreateCondBr(IsActiveThread, W0ThenBB, W0ElseBB);
+
+      emitBlock(W0ThenBB, Builder.GetInsertBlock()->getParent());
+
+      // SecMediumPtr = &medium[tid]
+      // SrcMediumVal = *SrcMediumPtr
+      Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
+          ArrayTy, TransferMedium, {Builder.getInt64(0), GPUThreadID});
+      // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
+      Value *TargetElemPtrPtr =
+          Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
+                                    {ConstantInt::get(IndexTy, 0),
+                                     ConstantInt::get(IndexTy, En.index())});
+      Value *TargetElemPtrVal =
+          Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtrPtr);
+      Value *TargetElemPtr = TargetElemPtrVal;
+      if (NumIters > 1)
+        TargetElemPtr =
+            Builder.CreateGEP(Builder.getInt32Ty(), TargetElemPtr, Cnt);
+
+      // *TargetElemPtr = SrcMediumVal;
+      Value *SrcMediumValue =
+          Builder.CreateLoad(CType, SrcMediumPtrVal, /*IsVolatile*/ true);
+      Builder.CreateStore(SrcMediumValue, TargetElemPtr);
+      Builder.CreateBr(W0MergeBB);
+
+      emitBlock(W0ElseBB, Builder.GetInsertBlock()->getParent());
+      Builder.CreateBr(W0MergeBB);
+
+      emitBlock(W0MergeBB, Builder.GetInsertBlock()->getParent());
+
+      // Loop latch: bump the counter, jump back to the condition, and
+      // continue code generation in the exit block.
+      if (NumIters > 1) {
+        Cnt = Builder.CreateNSWAdd(
+            Cnt, ConstantInt::get(Builder.getInt32Ty(), /*V=*/1));
+        Builder.CreateStore(Cnt, CntAddr, /*Volatile=*/false);
+
+        auto *CurFn = Builder.GetInsertBlock()->getParent();
+        emitBranch(PrecondBB);
+        emitBlock(ExitBB, CurFn);
+      }
+      RealTySize %= TySize;
+    }
+  }
+
+  Builder.CreateRetVoid();
+  Builder.restoreIP(SavedIP);
+
+  return WcFunc;
+}
+
+/// Create and return the internal helper function
+/// "_omp_reduction_shuffle_and_reduce_func" with the signature
+/// void(ptr ReduceList, i16 LaneId, i16 RemoteLaneOffset, i16 AlgoVer).
+///
+/// The generated body
+///  1. copies every element of the reduce list from a remote lane into a
+///     stack-allocated remote reduce list (emitReductionListCopy with
+///     CopyAction::RemoteLaneToThread),
+///  2. calls \p ReduceFn(ReduceList, RemoteReduceList) when the algorithm
+///     predicate described below holds, and
+///  3. copies the remote list over the local one when AlgoVer == 1 and
+///     LaneId >= RemoteLaneOffset (CopyAction::ThreadCopy).
+///
+/// The builder's insertion point is saved on entry and restored before
+/// returning, so the caller's position is left untouched.
+///
+/// \param ReductionInfos Element information for the reduce list entries.
+/// \param ReduceFn Function taking (local list, remote list) that performs
+///                 the actual reduction.
+/// \param FuncAttrs Attributes applied to the created function.
+Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
+    AttributeList FuncAttrs) {
+  // Remember where the caller was emitting code; this function moves the
+  // builder into the new helper and must put it back afterwards.
+  InsertPointTy SavedIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy =
+      FunctionType::get(Builder.getVoidTy(),
+                        {Builder.getPtrTy(), Builder.getInt16Ty(),
+                         Builder.getInt16Ty(), Builder.getInt16Ty()},
+                        /* IsVarArg */ false);
+  Function *SarFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_shuffle_and_reduce_func", &M);
+  SarFunc->setAttributes(FuncAttrs);
+  SarFunc->addParamAttr(0, Attribute::NoUndef);
+  SarFunc->addParamAttr(1, Attribute::NoUndef);
+  SarFunc->addParamAttr(2, Attribute::NoUndef);
+  SarFunc->addParamAttr(3, Attribute::NoUndef);
+  SarFunc->addParamAttr(1, Attribute::SExt);
+  SarFunc->addParamAttr(2, Attribute::SExt);
+  SarFunc->addParamAttr(3, Attribute::SExt);
+  BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
+  Builder.SetInsertPoint(EntryBB);
+
+  // Thread local Reduce list used to host the values of data to be reduced.
+  Argument *ReduceListArg = SarFunc->getArg(0);
+  // Current lane id; could be logical.
+  Argument *LaneIDArg = SarFunc->getArg(1);
+  // Offset of the remote source lane relative to the current lane.
+  Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
+  // Algorithm version.  This is expected to be known at compile time.
+  Argument *AlgoVerArg = SarFunc->getArg(3);
+
+  // All three i16 arguments share LaneIDArg's type.
+  Type *ReduceListArgType = ReduceListArg->getType();
+  Type *LaneIDArgType = LaneIDArg->getType();
+  Type *LaneIDArgPtrType = LaneIDArg->getType()->getPointerTo();
+  Value *ReduceListAlloca = Builder.CreateAlloca(
+      ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
+  Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
+                                             LaneIDArg->getName() + ".addr");
+  Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
+      LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
+  Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
+                                              AlgoVerArg->getName() + ".addr");
+  ArrayType *RedListArrayTy =
+      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+
+  // Create a local thread-private variable to host the Reduce list
+  // from a remote lane.
+  Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
+      RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");
+
+  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListAlloca, ReduceListArgType,
+      ReduceListAlloca->getName() + ".ascast");
+  Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
+  Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      RemoteLaneOffsetAlloca, LaneIDArgPtrType,
+      RemoteLaneOffsetAlloca->getName() + ".ascast");
+  Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
+  Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      RemoteReductionListAlloca, Builder.getPtrTy(),
+      RemoteReductionListAlloca->getName() + ".ascast");
+
+  // Spill the arguments to their stack slots and reload them.
+  Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
+  Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
+  Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
+  Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);
+
+  Value *ReduceList = Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
+  Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
+  Value *RemoteLaneOffset =
+      Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
+  Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);
+
+  // Temporaries created by the list copies go right after the remote list
+  // alloca in the entry block.
+  InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);
+
+  // This loop iterates through the list of reduce elements and copies,
+  // element by element, from a remote lane in the warp to RemoteReduceList,
+  // hosted on the thread's stack.
+  emitReductionListCopy(
+      AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
+      ReduceList, RemoteListAddrCast, {RemoteLaneOffset, nullptr, nullptr});
+
+  // The actions to be performed on the Remote Reduce list is dependent
+  // on the algorithm version.
+  //
+  //  if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
+  //  LaneId % 2 == 0 && Offset > 0):
+  //    do the reduction value aggregation
+  //
+  //  The thread local variable Reduce list is mutated in place to host the
+  //  reduced data, which is the aggregated value produced from local and
+  //  remote lanes.
+  //
+  //  Note that AlgoVer is expected to be a constant integer known at compile
+  //  time.
+  //  When AlgoVer==0, the first conjunction evaluates to true, making
+  //    the entire predicate true during compile time.
+  //  When AlgoVer==1, the second conjunction has only the second part to be
+  //    evaluated during runtime.  Other conjunctions evaluates to false
+  //    during compile time.
+  //  When AlgoVer==2, the third conjunction has only the second part to be
+  //    evaluated during runtime.  Other conjunctions evaluates to false
+  //    during compile time.
+  Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
+  Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
+  Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
+  Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
+  Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
+  Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
+  Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
+  Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
+  Value *RemoteOffsetComp =
+      Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
+  Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
+  Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
+  Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);
+
+  BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
+  BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
+  BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+  Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
+  emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
+  // reduce_function(LocalReduceList, RemoteReduceList)
+  Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceList, Builder.getPtrTy());
+  Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      RemoteListAddrCast, Builder.getPtrTy());
+  Builder.CreateCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
+      ->addFnAttr(Attribute::NoUnwind);
+  Builder.CreateBr(MergeBB);
+
+  emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
+  Builder.CreateBr(MergeBB);
+
+  emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
+
+  // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
+  // Reduce list.
+  Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
+  Value *LaneIdGeOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
+  Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGeOffset);
+
+  BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
+  BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
+  BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
+  Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);
+
+  emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());
+  emitReductionListCopy(AllocaIP, CopyAction::ThreadCopy, RedListArrayTy,
+                        ReductionInfos, RemoteListAddrCast, ReduceList);
+  Builder.CreateBr(CpyMergeBB);
+
+  emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
+  Builder.CreateBr(CpyMergeBB);
+
+  emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());
+
+  Builder.CreateRetVoid();
+  // Hand the builder back to wherever the caller was emitting code.
+  Builder.restoreIP(SavedIP);
+
+  return SarFunc;
+}
+
+/// Emit the internal helper `_omp_reduction_list_to_global_copy_func`.
+///
+/// The generated function has the signature
+///   void (ptr Buffer, i32 Idx, ptr ReduceList)
+/// and, for each entry I of \p ReductionInfos, copies the value addressed by
+/// `ReduceList[I]` into field I of `Buffer[Idx]`.
+///
+/// \param ReductionInfos One descriptor per reduction variable; its
+///        `ElementType` and `EvaluationKind` select how the value is copied.
+/// \param ReductionsBufferTy Element type used to index the global buffer.
+///        NOTE(review): presumably a struct with one field per reduction, in
+///        the same order as \p ReductionInfos; confirm against the caller.
+/// \param FuncAttrs Attribute list installed on the generated function.
+/// \returns The generated function. The builder's insertion point is restored
+///          before returning.
+Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
+    AttributeList FuncAttrs) {
+  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy = FunctionType::get(
+      Builder.getVoidTy(),
+      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
+      /* IsVarArg */ false);
+  Function *LtGCFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_list_to_global_copy_func", &M);
+  LtGCFunc->setAttributes(FuncAttrs);
+  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
+  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
+  LtGCFunc->addParamAttr(2, Attribute::NoUndef);
+
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
+  Builder.SetInsertPoint(EntryBlock);
+
+  // Buffer: global reduction buffer.
+  Argument *BufferArg = LtGCFunc->getArg(0);
+  // Idx: index of the buffer.
+  Argument *IdxArg = LtGCFunc->getArg(1);
+  // ReduceList: thread local Reduce list.
+  Argument *ReduceListArg = LtGCFunc->getArg(2);
+
+  // Spill every argument to a stack slot and access it through a pointer of
+  // the builder's default pointer type.
+  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
+                                                BufferArg->getName() + ".addr");
+  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
+                                             IdxArg->getName() + ".addr");
+  Value *ReduceListArgAlloca = Builder.CreateAlloca(
+      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
+  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      BufferArgAlloca, Builder.getPtrTy(),
+      BufferArgAlloca->getName() + ".ascast");
+  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
+  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListArgAlloca, Builder.getPtrTy(),
+      ReduceListArgAlloca->getName() + ".ascast");
+
+  Builder.CreateStore(BufferArg, BufferArgAddrCast);
+  Builder.CreateStore(IdxArg, IdxArgAddrCast);
+  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
+
+  Value *LocalReduceList =
+      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+  Value *BufferArgVal =
+      Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
+  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  // The reduce list is an array holding one pointer per reduction variable.
+  // The type does not depend on the loop iteration, so build it once.
+  auto *RedListArrayTy =
+      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+  for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    // ElemPtrPtr = &LocalReduceList[I]
+    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
+        RedListArrayTy, LocalReduceList,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    // ElemPtr = LocalReduceList[I], the address of the I-th reduction value.
+    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+
+    // Global = Buffer.VD[Idx];
+    Value *BufferVD =
+        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
+    Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
+        ReductionsBufferTy, BufferVD, 0, En.index(), "sum");
+
+    switch (RI.EvaluationKind) {
+    case EvaluationKindTy::Scalar: {
+      Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
+      Builder.CreateStore(TargetElement, GlobVal);
+      break;
+    }
+    case EvaluationKindTy::Complex: {
+      // Copy the value component-wise instead of silently emitting nothing.
+      // NOTE(review): assumes complex values are represented as a two-field
+      // struct {real, imag}; confirm with the producer of ReductionInfo.
+      assert(RI.ElementType->isStructTy() &&
+             RI.ElementType->getStructNumElements() == 2 &&
+             "expected a {real, imag} struct for a complex reduction");
+      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, ElemPtr, 0, 0, ".realp");
+      Value *SrcReal = Builder.CreateLoad(
+          RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
+      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, ElemPtr, 0, 1, ".imagp");
+      Value *SrcImg = Builder.CreateLoad(
+          RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
+
+      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, GlobVal, 0, 0, ".realp");
+      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, GlobVal, 0, 1, ".imagp");
+      Builder.CreateStore(SrcReal, DestRealPtr);
+      Builder.CreateStore(SrcImg, DestImgPtr);
+      break;
+    }
+    case EvaluationKindTy::Aggregate: {
+      // Aggregates are copied bytewise using the type's store size.
+      Value *SizeVal =
+          Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
+      Builder.CreateMemCpy(
+          GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
+          M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
+      break;
+    }
+    }
+  }
+
+  Builder.CreateRetVoid();
+  Builder.restoreIP(OldIP);
+  return LtGCFunc;
+}
+
+/// Emit the internal helper `_omp_reduction_list_to_global_reduce_func`.
+///
+/// The generated function has the signature
+///   void (ptr Buffer, i32 Idx, ptr ReduceList)
+/// It fills a stack-allocated pointer array with the addresses of the fields
+/// of `Buffer[Idx]` (one per entry of \p ReductionInfos) and then calls
+/// \p ReduceFn with that array as the first argument and the incoming
+/// `ReduceList` as the second.
+///
+/// \param ReductionInfos Only its size is used here: it fixes the length of
+///        the pointer array and the number of buffer fields addressed.
+/// \param ReduceFn Function invoked as `ReduceFn(GlobalList, ReduceList)`.
+/// \param ReductionsBufferTy Element type used to index the global buffer.
+/// \param FuncAttrs Attribute list installed on the generated function.
+/// \returns The generated function. The builder's insertion point is restored
+///          before returning.
+Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
+    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+  const OpenMPIRBuilder::InsertPointTy SavedIP = Builder.saveIP();
+
+  // Declare: void(ptr, i32, ptr), internal linkage, all params noundef.
+  Type *ParamTys[] = {Builder.getPtrTy(), Builder.getInt32Ty(),
+                      Builder.getPtrTy()};
+  FunctionType *HelperTy =
+      FunctionType::get(Builder.getVoidTy(), ParamTys, /* IsVarArg */ false);
+  Function *Helper =
+      Function::Create(HelperTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_list_to_global_reduce_func", &M);
+  Helper->setAttributes(FuncAttrs);
+  for (unsigned ArgNo = 0; ArgNo != 3; ++ArgNo)
+    Helper->addParamAttr(ArgNo, Attribute::NoUndef);
+
+  Builder.SetInsertPoint(BasicBlock::Create(M.getContext(), "entry", Helper));
+
+  Argument *Buffer = Helper->getArg(0);  // Global reduction buffer.
+  Argument *BufIdx = Helper->getArg(1);  // Index into the buffer.
+  Argument *ThreadList = Helper->getArg(2); // Thread local reduce list.
+
+  // Small emitters so the alloca and cast sequences below read uniformly.
+  auto EmitArgSlot = [&](Type *SlotTy, Argument *Arg) -> Value * {
+    return Builder.CreateAlloca(SlotTy, nullptr, Arg->getName() + ".addr");
+  };
+  auto EmitPtrCast = [&](Value *Slot) -> Value * {
+    return Builder.CreatePointerBitCastOrAddrSpaceCast(
+        Slot, Builder.getPtrTy(), Slot->getName() + ".ascast");
+  };
+
+  // Stack slots first, in argument order.
+  Value *BufferSlot = EmitArgSlot(Builder.getPtrTy(), Buffer);
+  Value *BufIdxSlot = EmitArgSlot(Builder.getInt32Ty(), BufIdx);
+  Value *ThreadListSlot = EmitArgSlot(Builder.getPtrTy(), ThreadList);
+
+  // void *RedList[<n>]: will hold pointers into the global buffer entry.
+  const size_t NumReductions = ReductionInfos.size();
+  ArrayType *PtrArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
+  Value *GlobalListSlot =
+      Builder.CreateAlloca(PtrArrayTy, nullptr, ".omp.reduction.red_list");
+
+  // Then the pointer casts, in the same order as the slots.
+  Value *BufferAddr = EmitPtrCast(BufferSlot);
+  Value *BufIdxAddr = EmitPtrCast(BufIdxSlot);
+  Value *ThreadListAddr = EmitPtrCast(ThreadListSlot);
+  Value *GlobalListAddr = EmitPtrCast(GlobalListSlot);
+
+  // Spill the incoming arguments.
+  Builder.CreateStore(Buffer, BufferAddr);
+  Builder.CreateStore(BufIdx, BufIdxAddr);
+  Builder.CreateStore(ThreadList, ThreadListAddr);
+
+  Value *BufferPtr = Builder.CreateLoad(Builder.getPtrTy(), BufferAddr);
+  Value *EntryIdx[] = {Builder.CreateLoad(Builder.getInt32Ty(), BufIdxAddr)};
+
+  const DataLayout &DL = M.getDataLayout();
+  Type *IndexTy = Builder.getIndexTy(DL, DL.getDefaultGlobalsAddressSpace());
+  Constant *Zero = ConstantInt::get(IndexTy, 0);
+
+  // GlobalList[I] = &Buffer[Idx].field(I)
+  for (size_t I = 0; I != NumReductions; ++I) {
+    Value *ListIdx[] = {Zero, ConstantInt::get(IndexTy, I)};
+    Value *ListSlot =
+        Builder.CreateInBoundsGEP(PtrArrayTy, GlobalListAddr, ListIdx);
+    Value *BufferEntry =
+        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferPtr, EntryIdx);
+    Value *FieldAddr = Builder.CreateConstInBoundsGEP2_32(
+        ReductionsBufferTy, BufferEntry, 0, I, "sum");
+    Builder.CreateStore(FieldAddr, ListSlot);
+  }
+
+  // reduce_function(GlobalReduceList, ReduceList)
+  Value *ThreadListPtr = Builder.CreateLoad(Builder.getPtrTy(), ThreadListAddr);
+  CallInst *ReduceCall =
+      Builder.CreateCall(ReduceFn, {GlobalListAddr, ThreadListPtr});
+  ReduceCall->addFnAttr(Attribute::NoUnwind);
+
+  Builder.CreateRetVoid();
+  Builder.restoreIP(SavedIP);
+  return Helper;
+}
+
+Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
+    AttributeList FuncAttrs) {
+  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy = FunctionType::get(
+      Builder.getVoidTy(),
+      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
+      /* IsVarArg */ false);
+  Function *LtGCFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_global_to_list_copy_func", &M);
+  LtGCFunc->setAttributes(FuncAttrs);
+  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
+  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
+  LtGCFunc->addParamAttr(2, Attribute::NoUndef);
+
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
+  Builder.SetInsertPoint(EntryBlock);
+
+  // Buffer: global reduction buffer.
+  Argument *BufferArg = LtGCFunc->getArg(0);
+  // Idx: index of the buffer.
+  Argument *IdxArg = LtGCFunc->getArg(1);
+  // ReduceList: thread local Reduce list.
+  Argument *ReduceListArg = LtGCFunc->getArg(2);
+
+  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
+                                                BufferArg->getName() + ".addr");
+  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
+                                             IdxArg->getName() + ".addr");
+  Value *ReduceListArgAlloca = Builder.CreateAlloca(
+      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
+  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      BufferArgAlloca, Builder.getPtrTy(),
+      BufferArgAlloca->getName() + ".ascast");
+  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
+  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListArgAlloca, Builder.getPtrTy(),
+      ReduceListArgAlloca->getName() + ".ascast");
+  Builder.CreateStore(BufferArg, BufferArgAddrCast);
+  Builder.CreateStore(IdxArg, IdxArgAddrCast);
+  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
+
+  Value *LocalReduceList =
+      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
+  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  for (auto En : enumerate(ReductionInfos)) {
+    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
+    auto *RedListArrayTy =
+        ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+    // Reduce element = LocalReduceList[i]
+    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
+        RedListArrayTy, LocalReduceList,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    // elemptr = ((CopyType*)(elemptrptr)) + I
+    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+    // Global = Buffer.VD[Idx];
+    Value *BufferVD =
+        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
+    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
+        ReductionsBufferTy, BufferVD, 0, En.index(), "sum");
+
+    switch (RI.EvaluationKind) {
+    case EvaluationKindTy::Scalar: {
+      Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
+      Builder.CreateStore(TargetElement, ElemPtr);
+      break;
+    }
+    case EvaluationKindTy::Complex: {
+      // FIXME(Jan): Complex type
----------------
jsjodin wrote:

If there are no tests that check this, we need to add tests and fix the implementation. 

https://github.com/llvm/llvm-project/pull/80343


More information about the Mlir-commits mailing list