[Mlir-commits] [clang] [llvm] [mlir] [OpenMP] Migrate GPU Reductions CodeGen from Clang to OMPIRBuilder (PR #80343)
    Akash Banerjee 
    llvmlistbot at llvm.org
       
    Tue Apr 16 07:54:42 PDT 2024
    
    
  
================
@@ -2096,15 +2102,1408 @@ OpenMPIRBuilder::createSection(const LocationDescription &Loc,
                               /*IsCancellable*/ true);
 }
 
-/// Create a function with a unique name and a "void (i8*, i8*)" signature in
-/// the given module and return it.
-Function *getFreshReductionFunc(Module &M) {
+/// Return an insertion point located directly after \p I in \p I's own basic
+/// block (the block's end() when \p I is its last instruction).
+static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
+  BasicBlock *Parent = I->getParent();
+  BasicBlock::iterator After = I->getIterator();
+  ++After;
+  return OpenMPIRBuilder::InsertPointTy(Parent, After);
+}
+
+/// Emit an appending-linkage global named \p Name, placed in the
+/// "llvm.metadata" section, whose initializer is an array of generic
+/// pointers to every still-live value recorded in \p List. Nothing is emitted
+/// when \p List has no live entries.
+void OpenMPIRBuilder::emitUsed(StringRef Name,
+                               std::vector<WeakTrackingVH> &List) {
+  if (List.empty())
+    return;
+
+  // Convert List to what ConstantArray needs. A WeakTrackingVH is nulled out
+  // when the value it tracks is deleted; skip such entries instead of
+  // handing a null pointer to cast<>.
+  SmallVector<Constant *, 8> UsedArray;
+  UsedArray.reserve(List.size());
+  for (WeakTrackingVH &Handle : List) {
+    if (!Handle)
+      continue;
+    UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*Handle), Builder.getPtrTy()));
+  }
+
+  // Every recorded value may have been deleted in the meantime.
+  if (UsedArray.empty())
+    return;
+  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
+
+  auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
+                                ConstantArray::get(ATy, UsedArray), Name);
+
+  GV->setSection("llvm.metadata");
+}
+
+/// Emit a call to the device runtime's
+/// __kmpc_get_hardware_thread_id_in_block and return the call's result.
+Value *OpenMPIRBuilder::getGPUThreadID() {
+  FunctionCallee ThreadIdFn = getOrCreateRuntimeFunction(
+      M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
+  return Builder.CreateCall(ThreadIdFn);
+}
+
+/// Emit a call to the device runtime's __kmpc_get_warp_size and return the
+/// call's result.
+Value *OpenMPIRBuilder::getGPUWarpSize() {
+  FunctionCallee WarpSizeFn =
+      getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size);
+  return Builder.CreateCall(WarpSizeFn);
+}
+
+/// Compute the warp index of the calling thread as
+/// thread_id >> floor(log2(GV_Warp_Size)) (an arithmetic shift). This equals
+/// thread_id / warp_size only when the configured warp size is a power of
+/// two — NOTE(review): assumed, not checked here.
+Value *OpenMPIRBuilder::getNVPTXWarpID() {
+  const unsigned WarpSizeLog2 = Log2_32(Config.getGridValue().GV_Warp_Size);
+  Value *ThreadID = getGPUThreadID();
+  return Builder.CreateAShr(ThreadID, WarpSizeLog2, "nvptx_warp_id");
+}
+
+/// Compute the lane index of the calling thread within its warp as
+/// thread_id & (2^floor(log2(GV_Warp_Size)) - 1).
+Value *OpenMPIRBuilder::getNVPTXLaneID() {
+  unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
+  assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
+  // Mask with the low LaneIDBits bits set. Written as (1 << bits) - 1 rather
+  // than ~0u >> (32 - bits): the latter shifts by 32 (undefined behaviour)
+  // when LaneIDBits == 0, which the assert above still admits. Both forms
+  // agree for LaneIDBits in [1, 31].
+  unsigned LaneIDMask = (1u << LaneIDBits) - 1u;
+  return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
+                           "nvptx_lane_id");
+}
+
+/// Convert \p From to \p ToType at the current insertion point:
+///  * same type            -> \p From is returned unchanged;
+///  * equal store sizes    -> bitcast;
+///  * both integer types   -> signed integer cast (sext/trunc);
+///  * anything else        -> \p From is stored to a temporary created at
+///                            \p AllocaIP and reloaded as \p ToType.
+Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
+                                        Type *ToType) {
+  const DataLayout &DL = M.getDataLayout();
+  Type *FromType = From->getType();
+  uint64_t FromSize = DL.getTypeStoreSize(FromType);
+  uint64_t ToSize = DL.getTypeStoreSize(ToType);
+  assert(FromSize > 0 && "From size must be greater than zero");
+  assert(ToSize > 0 && "To size must be greater than zero");
+  if (FromType == ToType)
+    return From;
+  if (FromSize == ToSize)
+    return Builder.CreateBitCast(From, ToType);
+  if (ToType->isIntegerTy() && FromType->isIntegerTy())
+    return Builder.CreateIntCast(From, ToType, /*isSigned*/ true);
+
+  // Round-trip through memory. The temporary has to hold the store of
+  // FromType and the load of ToType, so size it by the larger of the two and
+  // align it for both; sizing it by ToType alone lets the store run past the
+  // end of the slot whenever FromSize > ToSize.
+  Type *SlotTy = FromSize > ToSize ? FromType : ToType;
+  Align FromAlign = DL.getPrefTypeAlign(FromType);
+  Align ToAlign = DL.getPrefTypeAlign(ToType);
+  InsertPointTy SaveIP = Builder.saveIP();
+  Builder.restoreIP(AllocaIP);
+  AllocaInst *CastItem = Builder.CreateAlloca(SlotTy);
+  CastItem->setAlignment(FromAlign > ToAlign ? FromAlign : ToAlign);
+  Builder.restoreIP(SaveIP);
+
+  Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      CastItem, FromType->getPointerTo());
+  Builder.CreateStore(From, ValCastItem);
+  return Builder.CreateLoad(ToType, CastItem);
+}
+
+/// Fetch \p Element (of type \p ElementType, at most 8 bytes) from the lane
+/// \p Offset positions away. The value is widened to i32/i64, passed to the
+/// device runtime's __kmpc_shuffle_int32 / __kmpc_shuffle_int64, and the
+/// result is converted back to \p ElementType.
+Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
+                                                     Value *Element,
+                                                     Type *ElementType,
+                                                     Value *Offset) {
+  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
+  assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
+
+  // Cast all types to 32- or 64-bit values before calling shuffle routines.
+  Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
+  Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
+  // The warp size is passed to the runtime as an i16.
+  Value *WarpSize = Builder.CreateIntCast(getGPUWarpSize(),
+                                          Builder.getInt16Ty(),
+                                          /*isSigned=*/true);
+  Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
+      Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
+                : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
+  Value *ShuffleCall =
+      Builder.CreateCall(ShuffleFunc, {ElemCast, Offset, WarpSize});
+  // Undo the widening. Casting to CastTy here (the call's own type) was a
+  // no-op and handed callers an i32 for 1- and 2-byte elements, which they
+  // then had to narrow themselves or risk storing too many bytes.
+  return castValueToType(AllocaIP, ShuffleCall, ElementType);
+}
+
+/// Copy the \p ElemType value at \p SrcAddr, as seen by the lane \p Offset
+/// positions away, into \p DstAddr. The element is moved in 8-, 4-, 2- and
+/// 1-byte integer chunks, each passed through createRuntimeShuffleFunction;
+/// when more than one 8-byte chunk is needed a runtime loop is emitted.
+/// \p ReductionArrayTy is not used.
+void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
+                                      Value *DstAddr, Type *ElemType,
+                                      Value *Offset, Type *ReductionArrayTy) {
+  const DataLayout &DL = M.getDataLayout();
+  uint64_t Size = DL.getTypeStoreSize(ElemType);
+  // SrcAddr/DstAddr are only known to be ABI-aligned for ElemType. Each chunk
+  // starts at a multiple of its own size, so it is aligned to at least
+  // min(ElemAlign, chunk size). Claiming the chunk type's natural alignment
+  // instead would be wrong for e.g. a 4-byte-aligned {float, float} moved as
+  // a single i64.
+  Align ElemAlign = DL.getABITypeAlign(ElemType);
+  // Create the loop over the big sized data.
+  // ptr = (void*)Elem;
+  // ptrEnd = (void*) Elem + 1;
+  // Step = 8;
+  // while (ptr + Step < ptrEnd)
+  //   shuffle((int64_t)*ptr);
+  // Step = 4;
+  // while (ptr + Step < ptrEnd)
+  //   shuffle((int32_t)*ptr);
+  // ...
+  Type *IndexTy = Builder.getIndexTy(DL, DL.getDefaultGlobalsAddressSpace());
+  Value *ElemPtr = DstAddr;
+  Value *Ptr = SrcAddr;
+  for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
+    if (Size < IntSize)
+      continue;
+    Type *IntType = Builder.getIntNTy(IntSize * 8);
+    Align ChunkAlign = commonAlignment(ElemAlign, IntSize);
+    Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        Ptr, IntType->getPointerTo(), Ptr->getName() + ".ascast");
+    ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        ElemPtr, IntType->getPointerTo(), ElemPtr->getName() + ".ascast");
+
+    Function *CurFunc = Builder.GetInsertBlock()->getParent();
+    if ((Size / IntSize) > 1) {
+      // PtrEnd: one past the end of the source element.
+      Value *SrcAddrGEP =
+          Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
+      Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          SrcAddrGEP, Builder.getPtrTy());
+      BasicBlock *PreCondBB =
+          BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
+      BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
+      BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
+      BasicBlock *CurrentBB = Builder.GetInsertBlock();
+      emitBlock(PreCondBB, CurFunc);
+      PHINode *PhiSrc =
+          Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
+      PhiSrc->addIncoming(Ptr, CurrentBB);
+      PHINode *PhiDest =
+          Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
+      PhiDest->addIncoming(ElemPtr, CurrentBB);
+      Ptr = PhiSrc;
+      ElemPtr = PhiDest;
+      // Keep iterating while at least IntSize bytes of the source remain.
+      Value *PtrDiff = Builder.CreatePtrDiff(
+          Builder.getInt8Ty(), PtrEnd,
+          Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr, Builder.getPtrTy()));
+      Builder.CreateCondBr(
+          Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
+          ExitBB);
+      emitBlock(ThenBB, CurFunc);
+      Value *Res = createRuntimeShuffleFunction(
+          AllocaIP, Builder.CreateAlignedLoad(IntType, Ptr, ChunkAlign),
+          IntType, Offset);
+      Builder.CreateAlignedStore(Res, ElemPtr, ChunkAlign);
+      Value *LocalPtr =
+          Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
+      Value *LocalElemPtr =
+          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
+      PhiSrc->addIncoming(LocalPtr, ThenBB);
+      PhiDest->addIncoming(LocalElemPtr, ThenBB);
+      emitBranch(PreCondBB);
+      emitBlock(ExitBB, CurFunc);
+    } else {
+      Value *Res = createRuntimeShuffleFunction(
+          AllocaIP, Builder.CreateAlignedLoad(IntType, Ptr, ChunkAlign),
+          IntType, Offset);
+      // The shuffled value may come back widened to 32 bits. Narrow it to the
+      // chunk type so exactly IntSize bytes are written; storing the wide
+      // value would overrun the destination element (e.g. a 2-byte half, or
+      // the 2-byte tail of a 6-byte struct).
+      if (Res->getType()->getScalarSizeInBits() > IntSize * 8)
+        Res = Builder.CreateTrunc(Res, IntType);
+      Builder.CreateAlignedStore(Res, ElemPtr, ChunkAlign);
+      Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
+      ElemPtr =
+          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
+    }
+    Size = Size % IntSize;
+  }
+}
+
+/// For every entry of \p ReductionInfos, copy the element referenced by slot
+/// i of the reduce list \p SrcBase into the element for slot i of the reduce
+/// list \p DestBase (both indexed as \p ReductionArrayTy).
+///  * CopyAction::RemoteLaneToThread: a ".omp.reduction.element" temporary is
+///    allocated at \p AllocaIP and filled by shuffleAndStore with the value
+///    from the lane CopyOptions.RemoteLaneOffset away; the destination slot
+///    is then updated to point at that temporary.
+///  * CopyAction::ThreadCopy: the pointee of the destination slot is
+///    overwritten with the source value — a scalar load/store, a real/imag
+///    pair of load/stores, or a memcpy, selected by RI.EvaluationKind.
+void OpenMPIRBuilder::emitReductionListCopy(
+    InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
+    ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
+    CopyOptionsTy CopyOptions) {
+  const DataLayout &DL = M.getDataLayout();
+  Type *IdxTy = Builder.getIndexTy(DL, DL.getDefaultGlobalsAddressSpace());
+  Value *LaneOffset = CopyOptions.RemoteLaneOffset;
+  Value *Zero = ConstantInt::get(IdxTy, 0);
+
+  for (size_t Slot = 0, NumSlots = ReductionInfos.size(); Slot < NumSlots;
+       ++Slot) {
+    const ReductionInfo &RI = ReductionInfos[Slot];
+    Value *SlotIdx = ConstantInt::get(IdxTy, Slot);
+
+    // Source: &SrcBase[0][Slot] and the element pointer stored there.
+    Value *SrcSlotAddr =
+        Builder.CreateInBoundsGEP(ReductionArrayTy, SrcBase, {Zero, SlotIdx});
+    Value *SrcElem = Builder.CreateLoad(Builder.getPtrTy(), SrcSlotAddr);
+
+    // Destination: &DestBase[0][Slot]; the element it designates depends on
+    // the copy action.
+    Value *DestSlotAddr =
+        Builder.CreateInBoundsGEP(ReductionArrayTy, DestBase, {Zero, SlotIdx});
+    Value *DestElem = nullptr;
+    bool FromRemoteLane = false;
+    if (Action == CopyAction::RemoteLaneToThread) {
+      // The remote value lands in a private temporary emitted at AllocaIP.
+      InsertPointTy ResumeIP = Builder.saveIP();
+      Builder.restoreIP(AllocaIP);
+      AllocaInst *Tmp = Builder.CreateAlloca(RI.ElementType, nullptr,
+                                             ".omp.reduction.element");
+      Tmp->setAlignment(DL.getPrefTypeAlign(RI.ElementType));
+      DestElem = Builder.CreateAddrSpaceCast(Tmp, Builder.getPtrTy(),
+                                             Tmp->getName() + ".ascast");
+      Builder.restoreIP(ResumeIP);
+      FromRemoteLane = true;
+    } else if (Action == CopyAction::ThreadCopy) {
+      DestElem = Builder.CreateLoad(Builder.getPtrTy(), DestSlotAddr);
+    }
+
+    if (FromRemoteLane) {
+      shuffleAndStore(AllocaIP, SrcElem, DestElem, RI.ElementType, LaneOffset,
+                      ReductionArrayTy);
+    } else if (RI.EvaluationKind == EvalKind::Scalar) {
+      Value *Val = Builder.CreateLoad(RI.ElementType, SrcElem);
+      Builder.CreateStore(Val, DestElem);
+    } else if (RI.EvaluationKind == EvalKind::Complex) {
+      // Copy the real and imaginary members (struct fields 0 and 1).
+      Type *RealTy = RI.ElementType->getStructElementType(0);
+      Type *ImagTy = RI.ElementType->getStructElementType(1);
+      Value *SrcRealAddr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, SrcElem, 0, 0, ".realp");
+      Value *RealVal = Builder.CreateLoad(RealTy, SrcRealAddr, ".real");
+      Value *SrcImagAddr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, SrcElem, 0, 1, ".imagp");
+      Value *ImagVal = Builder.CreateLoad(ImagTy, SrcImagAddr, ".imag");
+      Value *DestRealAddr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, DestElem, 0, 0, ".realp");
+      Value *DestImagAddr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, DestElem, 0, 1, ".imagp");
+      Builder.CreateStore(RealVal, DestRealAddr);
+      Builder.CreateStore(ImagVal, DestImagAddr);
+    } else if (RI.EvaluationKind == EvalKind::Aggregate) {
+      Align ElemAlign = DL.getPrefTypeAlign(RI.ElementType);
+      Value *NumBytes = Builder.getInt64(DL.getTypeStoreSize(RI.ElementType));
+      Builder.CreateMemCpy(DestElem, ElemAlign, SrcElem, ElemAlign, NumBytes,
+                           false);
+    }
+
+    // RemoteReduceData[Slot] = (void *)&RemoteElem: publish the temporary so
+    // the reduce function sees the remote value through the list.
+    if (FromRemoteLane) {
+      Value *PublishedAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          DestElem, Builder.getPtrTy(), DestElem->getName() + ".ascast");
+      Builder.CreateStore(PublishedAddr, DestSlotAddr);
+    }
+  }
+}
+
+/// Emit the internal function "_omp_reduction_inter_warp_copy_func"
+/// (ptr ReduceList, i32 NumWarps). For every reduction element, walked in 4-,
+/// 2- and 1-byte chunks, the generated code:
+///   1. barriers, then has lane 0 of each warp store its chunk into
+///      __openmp_nvptx_data_transfer_temporary_storage[warp_id];
+///   2. barriers, then has each thread whose id is below NumWarps load
+///      storage[thread_id] into the same chunk of its own reduce-list element.
+/// The builder's insertion point is restored before returning.
+Function *OpenMPIRBuilder::emitInterWarpCopyFunction(
+    const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
+    AttributeList FuncAttrs) {
+  InsertPointTy SavedIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy = FunctionType::get(
+      Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()},
+      /* IsVarArg */ false);
+  Function *WcFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_inter_warp_copy_func", &M);
+  WcFunc->setAttributes(FuncAttrs);
+  WcFunc->addParamAttr(0, Attribute::NoUndef);
+  WcFunc->addParamAttr(1, Attribute::NoUndef);
+  BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc);
+  Builder.SetInsertPoint(EntryBB);
+
+  // ReduceList: thread local Reduce list.
+  // At the stage of the computation when this function is called, partially
+  // aggregated values reside in the first lane of every active warp.
+  Argument *ReduceListArg = WcFunc->getArg(0);
+  // NumWarps: number of warps active in the parallel region.  This could
+  // be smaller than 32 (max warps in a CTA) for partial block reduction.
+  Argument *NumWarpsArg = WcFunc->getArg(1);
+
+  // This array is used as a medium to transfer, one reduce element at a time,
+  // the data from the first lane of every warp to lanes in the first warp
+  // in order to perform the final step of a reduction in a parallel region
+  // (reduction across warps).  The array is placed in NVPTX __shared__ memory
+  // for reduced latency, as well as to have a distinct copy for concurrently
+  // executing target regions.  The array is declared with common linkage so
+  // as to be shared across compilation units.
+  StringRef TransferMediumName =
+      "__openmp_nvptx_data_transfer_temporary_storage";
+  GlobalVariable *TransferMedium = M.getGlobalVariable(TransferMediumName);
+  unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
+  ArrayType *ArrayTy = ArrayType::get(Builder.getInt32Ty(), WarpSize);
+  if (!TransferMedium) {
+    TransferMedium = new GlobalVariable(
+        M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
+        UndefValue::get(ArrayTy), TransferMediumName,
+        /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
+        /*AddressSpace=*/3);
+  }
+
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+
+  // Get the CUDA thread id of the current OpenMP thread on the GPU.
+  Value *GPUThreadID = getGPUThreadID();
+  // nvptx_lane_id = nvptx_id % warpsize
+  Value *LaneID = getNVPTXLaneID();
+  // nvptx_warp_id = nvptx_id / warpsize
+  Value *WarpID = getNVPTXWarpID();
+
+  InsertPointTy AllocaIP =
+      InsertPointTy(Builder.GetInsertBlock(),
+                    Builder.GetInsertBlock()->getFirstInsertionPt());
+  Type *Arg0Type = ReduceListArg->getType();
+  Type *Arg1Type = NumWarpsArg->getType();
+  Builder.restoreIP(AllocaIP);
+  AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
+      Arg0Type, nullptr, ReduceListArg->getName() + ".addr");
+  AllocaInst *NumWarpsAlloca =
+      Builder.CreateAlloca(Arg1Type, nullptr, NumWarpsArg->getName() + ".addr");
+  Value *ThreadID =
+      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize));
+  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListAlloca, Arg0Type, ReduceListAlloca->getName() + ".ascast");
+  Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      NumWarpsAlloca, Arg1Type->getPointerTo(),
+      NumWarpsAlloca->getName() + ".ascast");
+  Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
+  Builder.CreateStore(NumWarpsArg, NumWarpsAddrCast);
+  AllocaIP = getInsertPointAfterInstr(NumWarpsAlloca);
+  InsertPointTy CodeGenIP =
+      getInsertPointAfterInstr(&Builder.GetInsertBlock()->back());
+  Builder.restoreIP(CodeGenIP);
+
+  Value *ReduceList =
+      Builder.CreateLoad(Builder.getPtrTy(), ReduceListAddrCast);
+
+  // Shape of the thread-local reduce list and the GEP index type; both are
+  // loop invariant.
+  auto *RedListArrayTy =
+      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+
+  for (auto En : enumerate(ReductionInfos)) {
+    //
+    // Warp master copies reduce element to transfer medium in __shared__
+    // memory.
+    //
+    const ReductionInfo &RI = En.value();
+    unsigned RealTySize = M.getDataLayout().getTypeAllocSize(RI.ElementType);
+    // Element pointers are only known to be ABI-aligned for the element type.
+    // Every chunk starts at a multiple of its size, so it is aligned to at
+    // least min(ElemAlign, TySize); never claim more than that.
+    Align ElemAlign = M.getDataLayout().getABITypeAlign(RI.ElementType);
+    // Number of leading bytes of the element already moved by larger chunks.
+    unsigned ByteOffset = 0;
+    for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
+      Type *CType = Builder.getIntNTy(TySize * 8);
+      Align ChunkAlign = commonAlignment(ElemAlign, TySize);
+
+      unsigned NumIters = RealTySize / TySize;
+      if (NumIters == 0)
+        continue;
+      Value *Cnt = nullptr;
+      Value *CntAddr = nullptr;
+      BasicBlock *PrecondBB = nullptr;
+      BasicBlock *ExitBB = nullptr;
+      if (NumIters > 1) {
+        CodeGenIP = Builder.saveIP();
+        Builder.restoreIP(AllocaIP);
+        CntAddr =
+            Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, ".cnt.addr");
+
+        CntAddr = Builder.CreateAddrSpaceCast(CntAddr, Builder.getPtrTy(),
+                                              CntAddr->getName() + ".ascast");
+        Builder.restoreIP(CodeGenIP);
+        Builder.CreateStore(Constant::getNullValue(Builder.getInt32Ty()),
+                            CntAddr,
+                            /*Volatile=*/false);
+        PrecondBB = BasicBlock::Create(Ctx, "precond");
+        ExitBB = BasicBlock::Create(Ctx, "exit");
+        BasicBlock *BodyBB = BasicBlock::Create(Ctx, "body");
+        emitBlock(PrecondBB, Builder.GetInsertBlock()->getParent());
+        Cnt = Builder.CreateLoad(Builder.getInt32Ty(), CntAddr,
+                                 /*Volatile=*/false);
+        Value *Cmp = Builder.CreateICmpULT(
+            Cnt, ConstantInt::get(Builder.getInt32Ty(), NumIters));
+        Builder.CreateCondBr(Cmp, BodyBB, ExitBB);
+        emitBlock(BodyBB, Builder.GetInsertBlock()->getParent());
+      }
+
+      // kmpc_barrier.
+      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
+                    omp::Directive::OMPD_unknown,
+                    /* ForceSimpleCall */ false,
+                    /* CheckCancelFlag */ true, ThreadID);
+      BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
+      BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
+      BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+      // if (lane_id  == 0)
+      Value *IsWarpMaster = Builder.CreateIsNull(LaneID, "warp_master");
+      Builder.CreateCondBr(IsWarpMaster, ThenBB, ElseBB);
+      emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
+
+      // Reduce element = LocalReduceList[i]
+      Value *ElemPtrPtr =
+          Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
+                                    {ConstantInt::get(IndexTy, 0),
+                                     ConstantInt::get(IndexTy, En.index())});
+      // elemptr = ((CopyType*)((char*)elemptrptr + ByteOffset)) + I
+      Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+      // Skip the bytes already moved by larger chunks. Without this, the 2-
+      // and 1-byte tail chunks of e.g. a 3- or 6-byte element re-copy the
+      // element's leading bytes and its tail is never transferred.
+      if (ByteOffset != 0)
+        ElemPtr = Builder.CreateConstInBoundsGEP1_64(Builder.getInt8Ty(),
+                                                     ElemPtr, ByteOffset);
+      if (NumIters > 1)
+        ElemPtr = Builder.CreateGEP(CType, ElemPtr, Cnt);
+
+      // Get pointer to location in transfer medium.
+      // MediumPtr = &medium[warp_id]
+      Value *MediumPtr = Builder.CreateInBoundsGEP(
+          ArrayTy, TransferMedium, {Builder.getInt64(0), WarpID});
+      // elem = *elemptr
+      //*MediumPtr = elem
+      Value *Elem = Builder.CreateAlignedLoad(CType, ElemPtr, ChunkAlign);
+      // Store the source element value to the dest element address.
+      Builder.CreateStore(Elem, MediumPtr,
+                          /*IsVolatile*/ true);
+      Builder.CreateBr(MergeBB);
+
+      // else
+      emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
+      Builder.CreateBr(MergeBB);
+
+      // endif
+      emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
+      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
+                    omp::Directive::OMPD_unknown,
+                    /* ForceSimpleCall */ false,
+                    /* CheckCancelFlag */ true, ThreadID);
+
+      // Warp 0 copies reduce element from transfer medium
+      BasicBlock *W0ThenBB = BasicBlock::Create(Ctx, "then");
+      BasicBlock *W0ElseBB = BasicBlock::Create(Ctx, "else");
+      BasicBlock *W0MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+      Value *NumWarpsVal =
+          Builder.CreateLoad(Builder.getInt32Ty(), NumWarpsAddrCast);
+      // Up to 32 threads in warp 0 are active.
+      Value *IsActiveThread =
+          Builder.CreateICmpULT(GPUThreadID, NumWarpsVal, "is_active_thread");
+      Builder.CreateCondBr(IsActiveThread, W0ThenBB, W0ElseBB);
+
+      emitBlock(W0ThenBB, Builder.GetInsertBlock()->getParent());
+
+      // SecMediumPtr = &medium[tid]
+      // SrcMediumVal = *SrcMediumPtr
+      Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
+          ArrayTy, TransferMedium, {Builder.getInt64(0), GPUThreadID});
+      // TargetElemPtr = (CopyType*)((char*)SrcDataAddr[i] + ByteOffset) + I
+      Value *TargetElemPtrPtr =
+          Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
+                                    {ConstantInt::get(IndexTy, 0),
+                                     ConstantInt::get(IndexTy, En.index())});
+      Value *TargetElemPtrVal =
+          Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtrPtr);
+      Value *TargetElemPtr = TargetElemPtrVal;
+      if (ByteOffset != 0)
+        TargetElemPtr = Builder.CreateConstInBoundsGEP1_64(
+            Builder.getInt8Ty(), TargetElemPtr, ByteOffset);
+      if (NumIters > 1)
+        TargetElemPtr = Builder.CreateGEP(CType, TargetElemPtr, Cnt);
+
+      // *TargetElemPtr = SrcMediumVal;
+      Value *SrcMediumValue =
+          Builder.CreateLoad(CType, SrcMediumPtrVal, /*IsVolatile*/ true);
+      Builder.CreateAlignedStore(SrcMediumValue, TargetElemPtr, ChunkAlign);
+      Builder.CreateBr(W0MergeBB);
+
+      emitBlock(W0ElseBB, Builder.GetInsertBlock()->getParent());
+      Builder.CreateBr(W0MergeBB);
+
+      emitBlock(W0MergeBB, Builder.GetInsertBlock()->getParent());
+
+      if (NumIters > 1) {
+        Cnt = Builder.CreateNSWAdd(
+            Cnt, ConstantInt::get(Builder.getInt32Ty(), /*V=*/1));
+        Builder.CreateStore(Cnt, CntAddr, /*Volatile=*/false);
+
+        auto *CurFn = Builder.GetInsertBlock()->getParent();
+        emitBranch(PrecondBB);
+        emitBlock(ExitBB, CurFn);
+      }
+      ByteOffset += NumIters * TySize;
+      RealTySize %= TySize;
+    }
+  }
+
+  Builder.CreateRetVoid();
+  Builder.restoreIP(SavedIP);
+
+  return WcFunc;
+}
+
+/// Emit the internal function "_omp_reduction_shuffle_and_reduce_func"
+/// (ptr ReduceList, i16 LaneId, i16 RemoteLaneOffset, i16 AlgoVer). It copies
+/// the reduce list of the lane RemoteLaneOffset away into a local remote
+/// list. When
+///   AlgoVer == 0 || (AlgoVer == 1 && LaneId < Offset) ||
+///   (AlgoVer == 2 && LaneId % 2 == 0 && Offset > 0)
+/// it calls \p ReduceFn(local, remote). When AlgoVer == 1 && LaneId >= Offset
+/// it copies the remote list into the local one. Like the other emitters
+/// here, the builder's insertion point is restored before returning.
+Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
+    AttributeList FuncAttrs) {
+  // Generating the body moves the builder into the new function; remember
+  // where the caller was so it can be restored (as emitInterWarpCopyFunction
+  // does) instead of leaking the new insertion point.
+  InsertPointTy SavedIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy =
+      FunctionType::get(Builder.getVoidTy(),
+                        {Builder.getPtrTy(), Builder.getInt16Ty(),
+                         Builder.getInt16Ty(), Builder.getInt16Ty()},
+                        /* IsVarArg */ false);
+  Function *SarFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_shuffle_and_reduce_func", &M);
+  SarFunc->setAttributes(FuncAttrs);
+  SarFunc->addParamAttr(0, Attribute::NoUndef);
+  SarFunc->addParamAttr(1, Attribute::NoUndef);
+  SarFunc->addParamAttr(2, Attribute::NoUndef);
+  SarFunc->addParamAttr(3, Attribute::NoUndef);
+  SarFunc->addParamAttr(1, Attribute::SExt);
+  SarFunc->addParamAttr(2, Attribute::SExt);
+  SarFunc->addParamAttr(3, Attribute::SExt);
+  BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
+  Builder.SetInsertPoint(EntryBB);
+
+  // Thread local Reduce list used to host the values of data to be reduced.
+  Argument *ReduceListArg = SarFunc->getArg(0);
+  // Current lane id; could be logical.
+  Argument *LaneIDArg = SarFunc->getArg(1);
+  // Offset of the remote source lane relative to the current lane.
+  Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
+  // Algorithm version.  This is expected to be known at compile time.
+  Argument *AlgoVerArg = SarFunc->getArg(3);
+
+  Type *ReduceListArgType = ReduceListArg->getType();
+  Type *LaneIDArgType = LaneIDArg->getType();
+  Type *LaneIDArgPtrType = LaneIDArg->getType()->getPointerTo();
+  Value *ReduceListAlloca = Builder.CreateAlloca(
+      ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
+  Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
+                                             LaneIDArg->getName() + ".addr");
+  Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
+      LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
+  Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
+                                              AlgoVerArg->getName() + ".addr");
+  ArrayType *RedListArrayTy =
+      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+
+  // Create a local thread-private variable to host the Reduce list
+  // from a remote lane.
+  Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
+      RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");
+
+  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListAlloca, ReduceListArgType,
+      ReduceListAlloca->getName() + ".ascast");
+  Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
+  Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      RemoteLaneOffsetAlloca, LaneIDArgPtrType,
+      RemoteLaneOffsetAlloca->getName() + ".ascast");
+  Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
+  Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      RemoteReductionListAlloca, Builder.getPtrTy(),
+      RemoteReductionListAlloca->getName() + ".ascast");
+
+  Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
+  Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
+  Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
+  Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);
+
+  Value *ReduceList = Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
+  Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
+  Value *RemoteLaneOffset =
+      Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
+  Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);
+
+  // Further temporaries (e.g. the per-element remote copies) are placed right
+  // after the remote reduce list alloca.
+  InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);
+
+  // This loop iterates through the list of reduce elements and copies,
+  // element by element, from a remote lane in the warp to RemoteReduceList,
+  // hosted on the thread's stack.
+  emitReductionListCopy(
+      AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
+      ReduceList, RemoteListAddrCast, {RemoteLaneOffset, nullptr, nullptr});
+
+  // The actions to be performed on the Remote Reduce list is dependent
+  // on the algorithm version.
+  //
+  //  if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
+  //  LaneId % 2 == 0 && Offset > 0):
+  //    do the reduction value aggregation
+  //
+  //  The thread local variable Reduce list is mutated in place to host the
+  //  reduced data, which is the aggregated value produced from local and
+  //  remote lanes.
+  //
+  //  Note that AlgoVer is expected to be a constant integer known at compile
+  //  time.
+  //  When AlgoVer==0, the first conjunction evaluates to true, making
+  //    the entire predicate true during compile time.
+  //  When AlgoVer==1, the second conjunction has only the second part to be
+  //    evaluated during runtime.  Other conjunctions evaluates to false
+  //    during compile time.
+  //  When AlgoVer==2, the third conjunction has only the second part to be
+  //    evaluated during runtime.  Other conjunctions evaluates to false
+  //    during compile time.
+  Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
+  Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
+  Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
+  Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
+  Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
+  Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
+  Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
+  Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
+  Value *RemoteOffsetComp =
+      Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
+  Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
+  Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
+  Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);
+
+  BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
+  BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
+  BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+  Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
+  emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
+  // reduce_function(LocalReduceList, RemoteReduceList)
+  Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceList, Builder.getPtrTy());
+  Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      RemoteListAddrCast, Builder.getPtrTy());
+  Builder.CreateCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
+      ->addFnAttr(Attribute::NoUnwind);
+  Builder.CreateBr(MergeBB);
+
+  emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
+  Builder.CreateBr(MergeBB);
+
+  emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
+
+  // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
+  // Reduce list.
+  Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
+  Value *LaneIdGtOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
+  Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGtOffset);
+
+  BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
+  BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
+  BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
+  Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);
+
+  emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());
+  emitReductionListCopy(AllocaIP, CopyAction::ThreadCopy, RedListArrayTy,
+                        ReductionInfos, RemoteListAddrCast, ReduceList);
+  Builder.CreateBr(CpyMergeBB);
+
+  emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
+  Builder.CreateBr(CpyMergeBB);
+
+  emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());
+
+  Builder.CreateRetVoid();
+  Builder.restoreIP(SavedIP);
+
+  return SarFunc;
+}
+
+/// Emit the "list to global copy" helper used by teams reductions.
+///
+/// The generated internal function has the signature
+///   void _omp_reduction_list_to_global_copy_func(ptr Buffer, i32 Idx,
+///                                                ptr ReduceList)
+/// and, for every entry I of \p ReductionInfos, copies the element pointed to
+/// by ReduceList[I] into field I of Buffer[Idx], where Buffer is indexed as an
+/// array of \p ReductionsBufferTy. Scalars are copied with a load/store pair,
+/// complex values component by component (fields 0 and 1), and aggregates
+/// with a memcpy of the element's store size.
+///
+/// The builder's insertion point is saved on entry and restored on exit.
+Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
+    AttributeList FuncAttrs) {
+  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy = FunctionType::get(
+      Builder.getVoidTy(),
+      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
+      /* IsVarArg */ false);
+  Function *LtGCFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_list_to_global_copy_func", &M);
+  LtGCFunc->setAttributes(FuncAttrs);
+  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
+  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
+  LtGCFunc->addParamAttr(2, Attribute::NoUndef);
+
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
+  Builder.SetInsertPoint(EntryBlock);
+
+  // Buffer: global reduction buffer.
+  Argument *BufferArg = LtGCFunc->getArg(0);
+  // Idx: index of the buffer.
+  Argument *IdxArg = LtGCFunc->getArg(1);
+  // ReduceList: thread local Reduce list.
+  Argument *ReduceListArg = LtGCFunc->getArg(2);
+
+  // Spill the incoming arguments to allocas (cast to a plain `ptr`) and
+  // reload them below.
+  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
+                                                BufferArg->getName() + ".addr");
+  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
+                                             IdxArg->getName() + ".addr");
+  Value *ReduceListArgAlloca = Builder.CreateAlloca(
+      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
+  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      BufferArgAlloca, Builder.getPtrTy(),
+      BufferArgAlloca->getName() + ".ascast");
+  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
+  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListArgAlloca, Builder.getPtrTy(),
+      ReduceListArgAlloca->getName() + ".ascast");
+
+  Builder.CreateStore(BufferArg, BufferArgAddrCast);
+  Builder.CreateStore(IdxArg, IdxArgAddrCast);
+  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
+
+  Value *LocalReduceList =
+      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+  Value *BufferArgVal =
+      Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
+  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  // Type of the reduce list, [N x ptr]. It does not depend on the loop
+  // iteration, so build it once.
+  auto *RedListArrayTy =
+      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+  for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    // ElemPtrPtr = &LocalReduceList[I]
+    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
+        RedListArrayTy, LocalReduceList,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    // ElemPtr = LocalReduceList[I], the address of the thread-local element.
+    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+
+    // GlobVal = &Buffer[Idx].<field I>
+    Value *BufferVD =
+        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
+    Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
+        ReductionsBufferTy, BufferVD, 0, En.index(), "sum");
+
+    switch (RI.EvaluationKind) {
+    case EvalKind::Scalar: {
+      Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
+      Builder.CreateStore(TargetElement, GlobVal);
+      break;
+    }
+    case EvalKind::Complex: {
+      // Copy the real (field 0) and imaginary (field 1) parts separately.
+      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, ElemPtr, 0, 0, ".realp");
+      Value *SrcReal = Builder.CreateLoad(
+          RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
+      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, ElemPtr, 0, 1, ".imagp");
+      Value *SrcImg = Builder.CreateLoad(
+          RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
+
+      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, GlobVal, 0, 0, ".realp");
+      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, GlobVal, 0, 1, ".imagp");
+      Builder.CreateStore(SrcReal, DestRealPtr);
+      Builder.CreateStore(SrcImg, DestImgPtr);
+      break;
+    }
+    case EvalKind::Aggregate: {
+      // Byte-wise copy of the whole element (store size of ElementType).
+      Value *SizeVal =
+          Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
+      Builder.CreateMemCpy(
+          GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
+          M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
+      break;
+    }
+    }
+  }
+
+  Builder.CreateRetVoid();
+  Builder.restoreIP(OldIP);
+  return LtGCFunc;
+}
+
+/// Emit the "list to global reduce" helper used by teams reductions.
+///
+/// The generated internal function has the signature
+///   void _omp_reduction_list_to_global_reduce_func(ptr Buffer, i32 Idx,
+///                                                  ptr ReduceList)
+/// It builds a local [N x ptr] list whose entry I points at field I of
+/// Buffer[Idx] (Buffer indexed as an array of \p ReductionsBufferTy), then
+/// calls \p ReduceFn(<that global list>, ReduceList). The global slots are
+/// therefore passed as the first (LHS) argument and the thread-local list as
+/// the second.
+///
+/// The builder's insertion point is saved on entry and restored on exit.
+Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
+    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy = FunctionType::get(
+      Builder.getVoidTy(),
+      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
+      /* IsVarArg */ false);
+  Function *LtGRFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_list_to_global_reduce_func", &M);
+  LtGRFunc->setAttributes(FuncAttrs);
+  LtGRFunc->addParamAttr(0, Attribute::NoUndef);
+  LtGRFunc->addParamAttr(1, Attribute::NoUndef);
+  LtGRFunc->addParamAttr(2, Attribute::NoUndef);
+
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
+  Builder.SetInsertPoint(EntryBlock);
+
+  // Buffer: global reduction buffer.
+  Argument *BufferArg = LtGRFunc->getArg(0);
+  // Idx: index of the buffer.
+  Argument *IdxArg = LtGRFunc->getArg(1);
+  // ReduceList: thread local Reduce list.
+  Argument *ReduceListArg = LtGRFunc->getArg(2);
+
+  // Stack slots for the incoming arguments; they are stored and reloaded
+  // below.
+  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
+                                                BufferArg->getName() + ".addr");
+  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
+                                             IdxArg->getName() + ".addr");
+  Value *ReduceListArgAlloca = Builder.CreateAlloca(
+      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
+  auto *RedListArrayTy =
+      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+
+  // Local [N x ptr] list that will hold pointers into the global buffer slot
+  // Buffer[Idx] (one entry per reduction); it is filled in the loop below.
+  Value *LocalReduceList =
+      Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
+
+  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      BufferArgAlloca, Builder.getPtrTy(),
+      BufferArgAlloca->getName() + ".ascast");
+  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
+  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListArgAlloca, Builder.getPtrTy(),
+      ReduceListArgAlloca->getName() + ".ascast");
+  Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      LocalReduceList, Builder.getPtrTy(),
+      LocalReduceList->getName() + ".ascast");
+
+  Builder.CreateStore(BufferArg, BufferArgAddrCast);
+  Builder.CreateStore(IdxArg, IdxArgAddrCast);
+  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
+
+  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
+  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  for (auto En : enumerate(ReductionInfos)) {
+    // &LocalReduceList[I]
+    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
+        RedListArrayTy, LocalReduceListAddrCast,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    Value *BufferVD =
+        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
+    // Global = Buffer.VD[Idx];
+    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
+        ReductionsBufferTy, BufferVD, 0, En.index(), "sum");
+    // LocalReduceList[I] = &Buffer[Idx].<field I>
+    Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+  }
+
+  // Call reduce_function(GlobalReduceList, ReduceList)
+  Value *ReduceList =
+      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+  Builder.CreateCall(ReduceFn, {LocalReduceListAddrCast, ReduceList})
+      ->addFnAttr(Attribute::NoUnwind);
+  Builder.CreateRetVoid();
+  Builder.restoreIP(OldIP);
+  return LtGRFunc;
+}
+
+/// Emit the "global to list copy" helper used by teams reductions.
+///
+/// The generated internal function has the signature
+///   void _omp_reduction_global_to_list_copy_func(ptr Buffer, i32 Idx,
+///                                                ptr ReduceList)
+/// and, for every entry I of \p ReductionInfos, copies field I of Buffer[Idx]
+/// (Buffer indexed as an array of \p ReductionsBufferTy) into the element
+/// pointed to by ReduceList[I]. This is the reverse direction of
+/// emitListToGlobalCopyFunction. Scalars are copied with a load/store pair,
+/// complex values component by component (fields 0 and 1), and aggregates
+/// with a memcpy of the element's store size.
+///
+/// The builder's insertion point is saved on entry and restored on exit.
+Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
+    AttributeList FuncAttrs) {
+  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy = FunctionType::get(
+      Builder.getVoidTy(),
+      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
+      /* IsVarArg */ false);
+  // "GtLC" = global-to-list copy.
+  Function *GtLCFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_global_to_list_copy_func", &M);
+  GtLCFunc->setAttributes(FuncAttrs);
+  GtLCFunc->addParamAttr(0, Attribute::NoUndef);
+  GtLCFunc->addParamAttr(1, Attribute::NoUndef);
+  GtLCFunc->addParamAttr(2, Attribute::NoUndef);
+
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLCFunc);
+  Builder.SetInsertPoint(EntryBlock);
+
+  // Buffer: global reduction buffer.
+  Argument *BufferArg = GtLCFunc->getArg(0);
+  // Idx: index of the buffer.
+  Argument *IdxArg = GtLCFunc->getArg(1);
+  // ReduceList: thread local Reduce list.
+  Argument *ReduceListArg = GtLCFunc->getArg(2);
+
+  // Spill the incoming arguments to allocas (cast to a plain `ptr`) and
+  // reload them below.
+  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
+                                                BufferArg->getName() + ".addr");
+  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
+                                             IdxArg->getName() + ".addr");
+  Value *ReduceListArgAlloca = Builder.CreateAlloca(
+      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
+  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      BufferArgAlloca, Builder.getPtrTy(),
+      BufferArgAlloca->getName() + ".ascast");
+  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
+  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListArgAlloca, Builder.getPtrTy(),
+      ReduceListArgAlloca->getName() + ".ascast");
+  Builder.CreateStore(BufferArg, BufferArgAddrCast);
+  Builder.CreateStore(IdxArg, IdxArgAddrCast);
+  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
+
+  Value *LocalReduceList =
+      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
+  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  // Type of the reduce list, [N x ptr]. It does not depend on the loop
+  // iteration, so build it once.
+  auto *RedListArrayTy =
+      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+  for (auto En : enumerate(ReductionInfos)) {
+    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
+    // ElemPtrPtr = &LocalReduceList[I]
+    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
+        RedListArrayTy, LocalReduceList,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    // ElemPtr = LocalReduceList[I], the address of the thread-local element.
+    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+    // GlobValPtr = &Buffer[Idx].<field I>
+    Value *BufferVD =
+        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
+    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
+        ReductionsBufferTy, BufferVD, 0, En.index(), "sum");
+
+    switch (RI.EvaluationKind) {
+    case EvalKind::Scalar: {
+      Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
+      Builder.CreateStore(TargetElement, ElemPtr);
+      break;
+    }
+    case EvalKind::Complex: {
+      // Copy the real (field 0) and imaginary (field 1) parts separately.
+      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, GlobValPtr, 0, 0, ".realp");
+      Value *SrcReal = Builder.CreateLoad(
+          RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
+      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, GlobValPtr, 0, 1, ".imagp");
+      Value *SrcImg = Builder.CreateLoad(
+          RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
+
+      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, ElemPtr, 0, 0, ".realp");
+      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
+          RI.ElementType, ElemPtr, 0, 1, ".imagp");
+      Builder.CreateStore(SrcReal, DestRealPtr);
+      Builder.CreateStore(SrcImg, DestImgPtr);
+      break;
+    }
+    case EvalKind::Aggregate: {
+      // Byte-wise copy of the whole element (store size of ElementType).
+      Value *SizeVal =
+          Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
+      Builder.CreateMemCpy(
+          ElemPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
+          GlobValPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
+          SizeVal, false);
+      break;
+    }
+    }
+  }
+
+  Builder.CreateRetVoid();
+  Builder.restoreIP(OldIP);
+  return GtLCFunc;
+}
+
+/// Emit the "global to list reduce" helper used by teams reductions.
+///
+/// The generated internal function has the signature
+///   void _omp_reduction_global_to_list_reduce_func(ptr Buffer, i32 Idx,
+///                                                  ptr ReduceList)
+/// It builds a local [N x ptr] list whose entry I points at field I of
+/// Buffer[Idx] (Buffer indexed as an array of \p ReductionsBufferTy), then
+/// calls \p ReduceFn(ReduceList, <that global list>). This is the mirror of
+/// emitListToGlobalReduceFunction: here the thread-local list is the first
+/// (LHS) argument and the global slots are the second.
+///
+/// The builder's insertion point is saved on entry and restored on exit.
+Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
+    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  auto *FuncTy = FunctionType::get(
+      Builder.getVoidTy(),
+      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
+      /* IsVarArg */ false);
+  // "GtLR" = global-to-list reduce.
+  Function *GtLRFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_global_to_list_reduce_func", &M);
+  GtLRFunc->setAttributes(FuncAttrs);
+  GtLRFunc->addParamAttr(0, Attribute::NoUndef);
+  GtLRFunc->addParamAttr(1, Attribute::NoUndef);
+  GtLRFunc->addParamAttr(2, Attribute::NoUndef);
+
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLRFunc);
+  Builder.SetInsertPoint(EntryBlock);
+
+  // Buffer: global reduction buffer.
+  Argument *BufferArg = GtLRFunc->getArg(0);
+  // Idx: index of the buffer.
+  Argument *IdxArg = GtLRFunc->getArg(1);
+  // ReduceList: thread local Reduce list.
+  Argument *ReduceListArg = GtLRFunc->getArg(2);
+
+  // Stack slots for the incoming arguments; they are stored and reloaded
+  // below.
+  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
+                                                BufferArg->getName() + ".addr");
+  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
+                                             IdxArg->getName() + ".addr");
+  Value *ReduceListArgAlloca = Builder.CreateAlloca(
+      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
+  ArrayType *RedListArrayTy =
+      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+
+  // Local [N x ptr] list that will hold pointers into the global buffer slot
+  // Buffer[Idx] (one entry per reduction); it is filled in the loop below.
+  Value *LocalReduceList =
+      Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
+
+  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      BufferArgAlloca, Builder.getPtrTy(),
+      BufferArgAlloca->getName() + ".ascast");
+  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
+  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListArgAlloca, Builder.getPtrTy(),
+      ReduceListArgAlloca->getName() + ".ascast");
+  Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      LocalReduceList, Builder.getPtrTy(),
+      LocalReduceList->getName() + ".ascast");
+
+  Builder.CreateStore(BufferArg, BufferArgAddrCast);
+  Builder.CreateStore(IdxArg, IdxArgAddrCast);
+  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
+
+  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
+  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  for (auto En : enumerate(ReductionInfos)) {
+    // &ReductionList[I]
+    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
+        RedListArrayTy, ReductionList,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    // Global = Buffer.VD[Idx];
+    Value *BufferVD =
+        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
+    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
+        ReductionsBufferTy, BufferVD, 0, En.index(), "sum");
+    // ReductionList[I] = &Buffer[Idx].<field I>
+    Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+  }
+
+  // Call reduce_function(ReduceList, GlobalReduceList)
+  Value *ReduceList =
+      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+  Builder.CreateCall(ReduceFn, {ReduceList, ReductionList})
+      ->addFnAttr(Attribute::NoUnwind);
+  Builder.CreateRetVoid();
+  Builder.restoreIP(OldIP);
+  return GtLRFunc;
+}
+
+/// Build the symbol name of the reduction function for \p Name: \p Name
+/// followed by the platform-specific spelling of
+/// "omp" / "reduction" / "reduction_func".
+std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
+  const std::string PlatformSuffix =
+      createPlatformSpecificName({"omp", "reduction", "reduction_func"});
+  std::string FuncName = Name.str();
+  FuncName += PlatformSuffix;
+  return FuncName;
+}
+
+/// Create the internal reduction function named
+/// getReductionFuncName(\p ReducerName), with signature
+/// `void(ptr LHSList, ptr RHSList)`.
+///
+/// Both arguments are treated as [N x ptr] arrays holding one element pointer
+/// per entry in \p ReductionInfos; for every entry I the generated code
+/// combines RHSList[I] into LHSList[I]:
+///  - ReductionGenCBKind::Clang: the per-element pointers are collected
+///    first, then RI.ReductionGenClang emits the combiner and reports two
+///    placeholder values, whose uses inside the new function are rewritten to
+///    the real LHS/RHS element pointers.
+///  - otherwise: both elements are loaded, RI.ReductionGen produces the
+///    combined value, and it is stored back through the LHS element pointer.
+///
+/// The builder's insertion point is left inside the new function; callers are
+/// responsible for saving/restoring it. If the insertion point returned by a
+/// ReductionGen callback has no block, emission stops early and the partially
+/// built function (without a terminator) is returned.
+Function *OpenMPIRBuilder::createReductionFunction(
+    StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
+    ReductionGenCBKind ReductionGenCBTy, AttributeList FuncAttrs) {
+  auto *FuncTy = FunctionType::get(Builder.getVoidTy(),
+                                   {Builder.getPtrTy(), Builder.getPtrTy()},
+                                   /* IsVarArg */ false);
+  std::string Name = getReductionFuncName(ReducerName);
+  Function *ReductionFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage, Name, &M);
+  ReductionFunc->setAttributes(FuncAttrs);
+  ReductionFunc->addParamAttr(0, Attribute::NoUndef);
+  ReductionFunc->addParamAttr(1, Attribute::NoUndef);
+  BasicBlock *EntryBB =
+      BasicBlock::Create(M.getContext(), "entry", ReductionFunc);
+  Builder.SetInsertPoint(EntryBB);
+
+  // Spill both arguments to allocas (cast back to the argument type) and
+  // reload them; the reloaded values are the LHS and RHS [N x ptr] lists.
+  Argument *Arg0 = ReductionFunc->getArg(0);
+  Argument *Arg1 = ReductionFunc->getArg(1);
+  Type *Arg0Type = Arg0->getType();
+  Type *Arg1Type = Arg1->getType();
+
+  Value *LHSAlloca =
+      Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
+  Value *RHSAlloca =
+      Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
+  Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      LHSAlloca, Arg0Type, LHSAlloca->getName() + ".ascast");
+  Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      RHSAlloca, Arg1Type, RHSAlloca->getName() + ".ascast");
+  Builder.CreateStore(Arg0, LHSAddrCast);
+  Builder.CreateStore(Arg1, RHSAddrCast);
+  Value *LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
+  Value *RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
+
+  Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  SmallVector<Value *> LHSPtrs, RHSPtrs;
+  for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    // RHSPtr = RHSList[I], cast to the private variable's pointer type.
+    Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
+        RedArrayTy, RHSArrayPtr,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
+    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        RHSI8Ptr, RI.PrivateVariable->getType(),
+        RHSI8Ptr->getName() + ".ascast");
+
+    // LHSPtr = LHSList[I], cast to the original variable's pointer type.
+    Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
+        RedArrayTy, LHSArrayPtr,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
+    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        LHSI8Ptr, RI.Variable->getType(), LHSI8Ptr->getName() + ".ascast");
+
+    if (ReductionGenCBTy == ReductionGenCBKind::Clang) {
+      // Combiners are emitted after all pointers have been materialized.
+      LHSPtrs.emplace_back(LHSPtr);
+      RHSPtrs.emplace_back(RHSPtr);
+    } else {
+      Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
+      Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
+      Value *Reduced = nullptr;
+      // Resume emission at the insertion point handed back by the generator;
+      // an insertion point without a block is treated as a request to stop.
+      Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
+      if (!Builder.GetInsertBlock())
+        return ReductionFunc;
+      Builder.CreateStore(Reduced, LHSPtr);
+    }
+  }
+
+  if (ReductionGenCBTy == ReductionGenCBKind::Clang)
+    for (auto En : enumerate(ReductionInfos)) {
+      unsigned Index = En.index();
+      const ReductionInfo &RI = En.value();
+      Value *LHSFixupPtr = nullptr, *RHSFixupPtr = nullptr;
+      Builder.restoreIP(RI.ReductionGenClang(
+          Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));
+      assert(LHSFixupPtr && RHSFixupPtr &&
+             "ReductionGenClang must report the LHS/RHS placeholder values");
+
+      // Fix the CallBack code generated to use the correct Values for the LHS
+      // and RHS. Only uses inside the new reduction function are rewritten.
+      LHSFixupPtr->replaceUsesWithIf(
+          LHSPtrs[Index], [ReductionFunc](const Use &U) {
+            return cast<Instruction>(U.getUser())->getParent()->getParent() ==
+                   ReductionFunc;
+          });
+      RHSFixupPtr->replaceUsesWithIf(
+          RHSPtrs[Index], [ReductionFunc](const Use &U) {
+            return cast<Instruction>(U.getUser())->getParent()->getParent() ==
+                   ReductionFunc;
+          });
+    }
+
+  Builder.CreateRetVoid();
+  return ReductionFunc;
+}
+
+/// Debug-build sanity checks on the reduction descriptors passed to the
+/// reduction emitters. Every check is an assert, so this does nothing when
+/// assertions are disabled.
+///
+/// \param IsGPU when false, each Variable and its PrivateVariable must have
+///        exactly the same type; when true that check is skipped (the GPU
+///        path address-space-casts PrivateVariable where it uses it).
+static void
+checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
+                    bool IsGPU) {
+  const bool RequireSameType = !IsGPU;
+  for (const auto &RI : ReductionInfos) {
+    (void)RI; // Only referenced from asserts.
+    assert(RI.Variable && "expected non-null variable");
+    assert(RI.PrivateVariable && "expected non-null private variable");
+    assert((RI.ReductionGen || RI.ReductionGenClang) &&
+           "expected non-null reduction generator callback");
+    if (RequireSameType) {
+      assert(
+          RI.Variable->getType() == RI.PrivateVariable->getType() &&
+          "expected variables and their private equivalents to have the same "
+          "type");
+    }
+    assert(RI.Variable->getType()->isPointerTy() &&
+           "expected variables to be pointers");
+  }
+}
+
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU(
+    const LocationDescription &Loc, InsertPointTy AllocaIP,
+    InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
+    bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
+    ReductionGenCBKind ReductionGenCBTy, std::optional<omp::GV> GridValue,
+    unsigned ReductionBufNum, Value *SrcLocInfo) {
+  if (!updateToLocation(Loc))
+    return InsertPointTy();
+  Builder.restoreIP(CodeGenIP);
+  checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
+  LLVMContext &Ctx = M.getContext();
+
+  // Source location for the ident struct
+  if (!SrcLocInfo) {
+    uint32_t SrcLocStrSize;
+    Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+    SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+  }
+
+  if (ReductionInfos.size() == 0)
+    return Builder.saveIP();
+
+  Function *CurFunc = Builder.GetInsertBlock()->getParent();
+  AttributeList FuncAttrs;
+  AttrBuilder AttrBldr(Ctx);
+  for (auto Attr : CurFunc->getAttributes().getFnAttrs())
+    AttrBldr.addAttribute(Attr);
+  AttrBldr.removeAttribute(Attribute::OptimizeNone);
+  FuncAttrs = FuncAttrs.addFnAttributes(Ctx, AttrBldr);
+
+  Function *ReductionFunc = nullptr;
+  if (GLOBAL_ReductionFunc) {
+    ReductionFunc = GLOBAL_ReductionFunc;
+  } else {
+    CodeGenIP = Builder.saveIP();
+    ReductionFunc = createReductionFunction(
+        Builder.GetInsertBlock()->getParent()->getName(), ReductionInfos,
+        ReductionGenCBTy, FuncAttrs);
+    Builder.restoreIP(CodeGenIP);
+  }
+
+  // Set the grid value in the config needed for lowering later on
+  if (GridValue.has_value())
+    Config.setGridValue(GridValue.value());
+  else
+    Config.setGridValue(getGridValue(T, ReductionFunc));
+
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+  Value *RTLoc =
+      getOrCreateIdent(SrcLocStr, SrcLocStrSize, omp::IdentFlag(0), 0);
+
+  // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
+  // RedList, shuffle_reduce_func, interwarp_copy_func);
+  // or
+  // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
+  Value *Res;
+
+  // 1. Build a list of reduction variables.
+  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
+  auto Size = ReductionInfos.size();
+  Type *PtrTy = PointerType::getUnqual(Ctx);
+  Type *RedArrayTy = ArrayType::get(PtrTy, Size);
+  CodeGenIP = Builder.saveIP();
+  Builder.restoreIP(AllocaIP);
+  Value *ReductionListAlloca =
+      Builder.CreateAlloca(RedArrayTy, nullptr, ".omp.reduction.red_list");
+  Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReductionListAlloca, PtrTy, ReductionListAlloca->getName() + ".ascast");
+  Builder.restoreIP(CodeGenIP);
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    Value *ElemPtr = Builder.CreateInBoundsGEP(
+        RedArrayTy, ReductionList,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    Value *CastElem =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
+    Builder.CreateStore(CastElem, ElemPtr);
+  }
+  CodeGenIP = Builder.saveIP();
+  Function *SarFunc =
+      emitShuffleAndReduceFunction(ReductionInfos, ReductionFunc, FuncAttrs);
+  Function *WcFunc = emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs);
+  Builder.restoreIP(CodeGenIP);
+
+  Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(ReductionList, PtrTy);
+
+  unsigned MaxDataSize = 0;
+  SmallVector<Type *> ReductionTypeArgs;
+  for (auto En : enumerate(ReductionInfos)) {
+    auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
+    if (Size > MaxDataSize)
+      MaxDataSize = Size;
+    ReductionTypeArgs.emplace_back(En.value().ElementType);
+  }
+  Value *ReductionDataSize =
+      Builder.getInt64(MaxDataSize * ReductionInfos.size());
+  if (!IsTeamsReduction) {
+    Value *SarFuncCast =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(SarFunc, PtrTy);
+    Value *WcFuncCast =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(WcFunc, PtrTy);
+    Value *Args[] = {RTLoc, ReductionDataSize, RL, SarFuncCast, WcFuncCast};
+    Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
+        RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
+    Res = Builder.CreateCall(Pv2Ptr, Args);
+  } else {
+    CodeGenIP = Builder.saveIP();
+    StructType *ReductionsBufferTy = StructType::create(
+        Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
+    Function *RedFixedBuferFn = getOrCreateRuntimeFunctionPtr(
+        RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
+    Function *LtGCFunc = emitListToGlobalCopyFunction(
+        ReductionInfos, ReductionsBufferTy, FuncAttrs);
+    Function *LtGRFunc = emitListToGlobalReduceFunction(
+        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
+    Function *GtLCFunc = emitGlobalToListCopyFunction(
+        ReductionInfos, ReductionsBufferTy, FuncAttrs);
+    Function *GtLRFunc = emitGlobalToListReduceFunction(
+        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
+    Builder.restoreIP(CodeGenIP);
+
+    Value *KernelTeamsReductionPtr = Builder.CreateCall(
+        RedFixedBuferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
+
+    Value *Args3[] = {RTLoc,
+                      KernelTeamsReductionPtr,
+                      Builder.getInt32(ReductionBufNum),
+                      ReductionDataSize,
+                      RL,
+                      SarFunc,
+                      WcFunc,
+                      LtGCFunc,
+                      LtGRFunc,
+                      GtLCFunc,
+                      GtLRFunc};
+
+    Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
+        RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
+    Res = Builder.CreateCall(TeamsReduceFn, Args3);
+  }
+
+  // 5. Build if (res == 1)
+  BasicBlock *ExitBB = BasicBlock::Create(Ctx, ".omp.reduction.done");
+  BasicBlock *ThenBB = BasicBlock::Create(Ctx, ".omp.reduction.then");
+  Value *Cond = Builder.CreateICmpEQ(Res, Builder.getInt32(1));
+  Builder.CreateCondBr(Cond, ThenBB, ExitBB);
+
+  // 6. Build then branch: where we have reduced values in the master
+  //    thread in each team.
+  //    __kmpc_end_reduce{_nowait}(<gtid>);
+  //    break;
+  emitBlock(ThenBB, CurFunc);
+
+  // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
+  for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    Value *LHS = RI.Variable;
+    Value *RHS =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
+
+    if (ReductionGenCBTy == ReductionGenCBKind::Clang) {
+      Value *LHSPtr, *RHSPtr;
+      Builder.restoreIP(RI.ReductionGenClang(Builder.saveIP(), En.index(),
+                                             &LHSPtr, &RHSPtr, CurFunc));
+
+      // Fix the CallBack code genereated to use the correct Values for the LHS
+      // and RHS
+      LHSPtr->replaceUsesWithIf(LHS, [ReductionFunc](const Use &U) {
+        return cast<Instruction>(U.getUser())->getParent()->getParent() ==
+               ReductionFunc;
+      });
+      RHSPtr->replaceUsesWithIf(RHS, [ReductionFunc](const Use &U) {
+        return cast<Instruction>(U.getUser())->getParent()->getParent() ==
+               ReductionFunc;
+      });
+    } else {
+      // LHS = Builder.CreateLoad(LHS);
+      // LHS = Builder.CreateLoad(LHS);
+      // Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS));
----------------
TIFitis wrote:
Code should be added here when we add MLIR/Flang codegen support. This path is unreachable for now, so I've replaced the commented-out code with an assertion failure.
https://github.com/llvm/llvm-project/pull/80343
    
    
More information about the Mlir-commits
mailing list