[clang] [llvm] [mlir] [OpenMP] Migrate GPU Reductions CodeGen from Clang to OMPIRBuilder (PR #80343)
Akash Banerjee via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 16 07:54:42 PDT 2024
================
@@ -2096,15 +2102,1408 @@ OpenMPIRBuilder::createSection(const LocationDescription &Loc,
/*IsCancellable*/ true);
}
-/// Create a function with a unique name and a "void (i8*, i8*)" signature in
-/// the given module and return it.
-Function *getFreshReductionFunc(Module &M) {
+static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
+ BasicBlock::iterator IT(I);
+ IT++;
+ return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
+}
+
+void OpenMPIRBuilder::emitUsed(StringRef Name,
+ std::vector<WeakTrackingVH> &List) {
+ if (List.empty())
+ return;
+
+ // Convert List to what ConstantArray needs.
+ SmallVector<Constant *, 8> UsedArray;
+ UsedArray.resize(List.size());
+ for (unsigned I = 0, E = List.size(); I != E; ++I)
+ UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(&*List[I]), Builder.getPtrTy());
+
+ if (UsedArray.empty())
+ return;
+ ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
+
+ auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
+ ConstantArray::get(ATy, UsedArray), Name);
+
+ GV->setSection("llvm.metadata");
+}
+
+Value *OpenMPIRBuilder::getGPUThreadID() {
+ return Builder.CreateCall(
+ getOrCreateRuntimeFunction(M,
+ OMPRTL___kmpc_get_hardware_thread_id_in_block),
+ {});
+}
+
+Value *OpenMPIRBuilder::getGPUWarpSize() {
+ return Builder.CreateCall(
+ getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size), {});
+}
+
+Value *OpenMPIRBuilder::getNVPTXWarpID() {
+ unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
+ return Builder.CreateAShr(getGPUThreadID(), LaneIDBits, "nvptx_warp_id");
+}
+
+Value *OpenMPIRBuilder::getNVPTXLaneID() {
+ unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
+ assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
+ unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
+ return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
+ "nvptx_lane_id");
+}
+
+Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
+ Type *ToType) {
+ Type *FromType = From->getType();
+ uint64_t FromSize = M.getDataLayout().getTypeStoreSize(FromType);
+ uint64_t ToSize = M.getDataLayout().getTypeStoreSize(ToType);
+ assert(FromSize > 0 && "From size must be greater than zero");
+ assert(ToSize > 0 && "To size must be greater than zero");
+ if (FromType == ToType)
+ return From;
+ if (FromSize == ToSize)
+ return Builder.CreateBitCast(From, ToType);
+ if (ToType->isIntegerTy() && FromType->isIntegerTy())
+ return Builder.CreateIntCast(From, ToType, /*isSigned*/ true);
+ InsertPointTy SaveIP = Builder.saveIP();
+ Builder.restoreIP(AllocaIP);
+ Value *CastItem = Builder.CreateAlloca(ToType);
+ Builder.restoreIP(SaveIP);
+
+ Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ CastItem, FromType->getPointerTo());
+ Builder.CreateStore(From, ValCastItem);
+ return Builder.CreateLoad(ToType, CastItem);
+}
+
+Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
+ Value *Element,
+ Type *ElementType,
+ Value *Offset) {
+ uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
+ assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
+
+ // Cast all types to 32- or 64-bit values before calling shuffle routines.
+ Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
+ Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
+ Value *WarpSize =
+ Builder.CreateIntCast(getGPUWarpSize(), Builder.getInt16Ty(), true);
+ Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
+ Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
+ : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
+ Value *WarpSizeCast =
+ Builder.CreateIntCast(WarpSize, Builder.getInt16Ty(), /*isSigned=*/true);
+ Value *ShuffleCall =
+ Builder.CreateCall(ShuffleFunc, {ElemCast, Offset, WarpSizeCast});
+ return castValueToType(AllocaIP, ShuffleCall, CastTy);
+}
+
+void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
+ Value *DstAddr, Type *ElemType,
+ Value *Offset, Type *ReductionArrayTy) {
+ uint64_t Size = M.getDataLayout().getTypeStoreSize(ElemType);
+ // Create the loop over the big sized data.
+ // ptr = (void*)Elem;
+ // ptrEnd = (void*) Elem + 1;
+ // Step = 8;
+ // while (ptr + Step < ptrEnd)
+ // shuffle((int64_t)*ptr);
+ // Step = 4;
+ // while (ptr + Step < ptrEnd)
+ // shuffle((int32_t)*ptr);
+ // ...
+ Type *IndexTy = Builder.getIndexTy(
+ M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+ Value *ElemPtr = DstAddr;
+ Value *Ptr = SrcAddr;
+ for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
+ if (Size < IntSize)
+ continue;
+ Type *IntType = Builder.getIntNTy(IntSize * 8);
+ Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ Ptr, IntType->getPointerTo(), Ptr->getName() + ".ascast");
+ Value *SrcAddrGEP =
+ Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
+ ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ElemPtr, IntType->getPointerTo(), ElemPtr->getName() + ".ascast");
+
+ Function *CurFunc = Builder.GetInsertBlock()->getParent();
+ if ((Size / IntSize) > 1) {
+ Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ SrcAddrGEP, Builder.getPtrTy());
+ BasicBlock *PreCondBB =
+ BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
+ BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
+ BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
+ BasicBlock *CurrentBB = Builder.GetInsertBlock();
+ emitBlock(PreCondBB, CurFunc);
+ PHINode *PhiSrc =
+ Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
+ PhiSrc->addIncoming(Ptr, CurrentBB);
+ PHINode *PhiDest =
+ Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
+ PhiDest->addIncoming(ElemPtr, CurrentBB);
+ Ptr = PhiSrc;
+ ElemPtr = PhiDest;
+ Value *PtrDiff = Builder.CreatePtrDiff(
+ Builder.getInt8Ty(), PtrEnd,
+ Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr, Builder.getPtrTy()));
+ Builder.CreateCondBr(
+ Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
+ ExitBB);
+ emitBlock(ThenBB, CurFunc);
+ Value *Res = createRuntimeShuffleFunction(
+ AllocaIP,
+ Builder.CreateAlignedLoad(
+ IntType, Ptr, M.getDataLayout().getPrefTypeAlign(ElemType)),
+ IntType, Offset);
+ Builder.CreateAlignedStore(Res, ElemPtr,
+ M.getDataLayout().getPrefTypeAlign(ElemType));
+ Value *LocalPtr =
+ Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
+ Value *LocalElemPtr =
+ Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
+ PhiSrc->addIncoming(LocalPtr, ThenBB);
+ PhiDest->addIncoming(LocalElemPtr, ThenBB);
+ emitBranch(PreCondBB);
+ emitBlock(ExitBB, CurFunc);
+ } else {
+ Value *Res = createRuntimeShuffleFunction(
+ AllocaIP, Builder.CreateLoad(IntType, Ptr), IntType, Offset);
+ if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
+ Res->getType()->getScalarSizeInBits())
+ Res = Builder.CreateTrunc(Res, ElemType);
+ Builder.CreateStore(Res, ElemPtr);
+ Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
+ ElemPtr =
+ Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
+ }
+ Size = Size % IntSize;
+ }
+}
+
+void OpenMPIRBuilder::emitReductionListCopy(
+ InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
+ ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
+ CopyOptionsTy CopyOptions) {
+ Type *IndexTy = Builder.getIndexTy(
+ M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+ Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;
+
+ // Iterates, element-by-element, through the source Reduce list and
+ // make a copy.
+ for (auto En : enumerate(ReductionInfos)) {
+ const ReductionInfo &RI = En.value();
+ Value *SrcElementAddr = nullptr;
+ Value *DestElementAddr = nullptr;
+ Value *DestElementPtrAddr = nullptr;
+ // Should we shuffle in an element from a remote lane?
+ bool ShuffleInElement = false;
+ // Set to true to update the pointer in the dest Reduce list to a
+ // newly created element.
+ bool UpdateDestListPtr = false;
+
+ // Step 1.1: Get the address for the src element in the Reduce list.
+ Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
+ ReductionArrayTy, SrcBase,
+ {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+ SrcElementAddr = Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrAddr);
+
+ // Step 1.2: Create a temporary to store the element in the destination
+ // Reduce list.
+ DestElementPtrAddr = Builder.CreateInBoundsGEP(
+ ReductionArrayTy, DestBase,
+ {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+ switch (Action) {
+ case CopyAction::RemoteLaneToThread: {
+ InsertPointTy CurIP = Builder.saveIP();
+ Builder.restoreIP(AllocaIP);
+ AllocaInst *DestAlloca = Builder.CreateAlloca(RI.ElementType, nullptr,
+ ".omp.reduction.element");
+ DestAlloca->setAlignment(
+ M.getDataLayout().getPrefTypeAlign(RI.ElementType));
+ DestElementAddr = DestAlloca;
+ DestElementAddr =
+ Builder.CreateAddrSpaceCast(DestElementAddr, Builder.getPtrTy(),
+ DestElementAddr->getName() + ".ascast");
+ Builder.restoreIP(CurIP);
+ ShuffleInElement = true;
+ UpdateDestListPtr = true;
+ break;
+ }
+ case CopyAction::ThreadCopy: {
+ DestElementAddr =
+ Builder.CreateLoad(Builder.getPtrTy(), DestElementPtrAddr);
+ break;
+ }
+ }
+
+ // Now that all active lanes have read the element in the
+ // Reduce list, shuffle over the value from the remote lane.
+ if (ShuffleInElement) {
+ shuffleAndStore(AllocaIP, SrcElementAddr, DestElementAddr, RI.ElementType,
+ RemoteLaneOffset, ReductionArrayTy);
+ } else {
+ switch (RI.EvaluationKind) {
+ case EvalKind::Scalar: {
+ Value *Elem = Builder.CreateLoad(RI.ElementType, SrcElementAddr);
+ // Store the source element value to the dest element address.
+ Builder.CreateStore(Elem, DestElementAddr);
+ break;
+ }
+ case EvalKind::Complex: {
+ Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
+ RI.ElementType, SrcElementAddr, 0, 0, ".realp");
+ Value *SrcReal = Builder.CreateLoad(
+ RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
+ Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
+ RI.ElementType, SrcElementAddr, 0, 1, ".imagp");
+ Value *SrcImg = Builder.CreateLoad(
+ RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
+
+ Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
+ RI.ElementType, DestElementAddr, 0, 0, ".realp");
+ Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
+ RI.ElementType, DestElementAddr, 0, 1, ".imagp");
+ Builder.CreateStore(SrcReal, DestRealPtr);
+ Builder.CreateStore(SrcImg, DestImgPtr);
+ break;
+ }
+ case EvalKind::Aggregate: {
+ Value *SizeVal = Builder.getInt64(
+ M.getDataLayout().getTypeStoreSize(RI.ElementType));
+ Builder.CreateMemCpy(
+ DestElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
+ SrcElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
+ SizeVal, false);
+ break;
+ }
+ };
+ }
+
+ // Step 3.1: Modify reference in dest Reduce list as needed.
+ // Modifying the reference in Reduce list to point to the newly
+ // created element. The element is live in the current function
+ // scope and that of functions it invokes (i.e., reduce_function).
+ // RemoteReduceData[i] = (void*)&RemoteElem
+ if (UpdateDestListPtr) {
+ Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ DestElementAddr, Builder.getPtrTy(),
+ DestElementAddr->getName() + ".ascast");
+ Builder.CreateStore(CastDestAddr, DestElementPtrAddr);
+ }
+ }
+}
+
+Function *OpenMPIRBuilder::emitInterWarpCopyFunction(
+ const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
+ AttributeList FuncAttrs) {
+ InsertPointTy SavedIP = Builder.saveIP();
+ LLVMContext &Ctx = M.getContext();
+ FunctionType *FuncTy = FunctionType::get(
+ Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()},
+ /* IsVarArg */ false);
+ Function *WcFunc =
+ Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+ "_omp_reduction_inter_warp_copy_func", &M);
+ WcFunc->setAttributes(FuncAttrs);
+ WcFunc->addParamAttr(0, Attribute::NoUndef);
+ WcFunc->addParamAttr(1, Attribute::NoUndef);
+ BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc);
+ Builder.SetInsertPoint(EntryBB);
+
+ // ReduceList: thread local Reduce list.
+ // At the stage of the computation when this function is called, partially
+ // aggregated values reside in the first lane of every active warp.
+ Argument *ReduceListArg = WcFunc->getArg(0);
+ // NumWarps: number of warps active in the parallel region. This could
+ // be smaller than 32 (max warps in a CTA) for partial block reduction.
+ Argument *NumWarpsArg = WcFunc->getArg(1);
+
+ // This array is used as a medium to transfer, one reduce element at a time,
+ // the data from the first lane of every warp to lanes in the first warp
+ // in order to perform the final step of a reduction in a parallel region
+ // (reduction across warps). The array is placed in NVPTX __shared__ memory
+ // for reduced latency, as well as to have a distinct copy for concurrently
+ // executing target regions. The array is declared with common linkage so
+ // as to be shared across compilation units.
+ StringRef TransferMediumName =
+ "__openmp_nvptx_data_transfer_temporary_storage";
+ GlobalVariable *TransferMedium = M.getGlobalVariable(TransferMediumName);
+ unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
+ ArrayType *ArrayTy = ArrayType::get(Builder.getInt32Ty(), WarpSize);
+ if (!TransferMedium) {
+ TransferMedium = new GlobalVariable(
+ M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
+ UndefValue::get(ArrayTy), TransferMediumName,
+ /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
+ /*AddressSpace=*/3);
+ }
+
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+
+ // Get the CUDA thread id of the current OpenMP thread on the GPU.
+ Value *GPUThreadID = getGPUThreadID();
+ // nvptx_lane_id = nvptx_id % warpsize
+ Value *LaneID = getNVPTXLaneID();
+ // nvptx_warp_id = nvptx_id / warpsize
+ Value *WarpID = getNVPTXWarpID();
+
+ InsertPointTy AllocaIP =
+ InsertPointTy(Builder.GetInsertBlock(),
+ Builder.GetInsertBlock()->getFirstInsertionPt());
+ Type *Arg0Type = ReduceListArg->getType();
+ Type *Arg1Type = NumWarpsArg->getType();
+ Builder.restoreIP(AllocaIP);
+ AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
+ Arg0Type, nullptr, ReduceListArg->getName() + ".addr");
+ AllocaInst *NumWarpsAlloca =
+ Builder.CreateAlloca(Arg1Type, nullptr, NumWarpsArg->getName() + ".addr");
+ Value *ThreadID =
+ getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize));
+ Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ReduceListAlloca, Arg0Type, ReduceListAlloca->getName() + ".ascast");
+ Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ NumWarpsAlloca, Arg1Type->getPointerTo(),
+ NumWarpsAlloca->getName() + ".ascast");
+ Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
+ Builder.CreateStore(NumWarpsArg, NumWarpsAddrCast);
+ AllocaIP = getInsertPointAfterInstr(NumWarpsAlloca);
+ InsertPointTy CodeGenIP =
+ getInsertPointAfterInstr(&Builder.GetInsertBlock()->back());
+ Builder.restoreIP(CodeGenIP);
+
+ Value *ReduceList =
+ Builder.CreateLoad(Builder.getPtrTy(), ReduceListAddrCast);
+
+ for (auto En : enumerate(ReductionInfos)) {
+ //
+ // Warp master copies reduce element to transfer medium in __shared__
+ // memory.
+ //
+ const ReductionInfo &RI = En.value();
+ unsigned RealTySize = M.getDataLayout().getTypeAllocSize(RI.ElementType);
+ for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
+ Type *CType = Builder.getIntNTy(TySize * 8);
+
+ unsigned NumIters = RealTySize / TySize;
+ if (NumIters == 0)
+ continue;
+ Value *Cnt = nullptr;
+ Value *CntAddr = nullptr;
+ BasicBlock *PrecondBB = nullptr;
+ BasicBlock *ExitBB = nullptr;
+ if (NumIters > 1) {
+ CodeGenIP = Builder.saveIP();
+ Builder.restoreIP(AllocaIP);
+ CntAddr =
+ Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, ".cnt.addr");
+
+ CntAddr = Builder.CreateAddrSpaceCast(CntAddr, Builder.getPtrTy(),
+ CntAddr->getName() + ".ascast");
+ Builder.restoreIP(CodeGenIP);
+ Builder.CreateStore(Constant::getNullValue(Builder.getInt32Ty()),
+ CntAddr,
+ /*Volatile=*/false);
+ PrecondBB = BasicBlock::Create(Ctx, "precond");
+ ExitBB = BasicBlock::Create(Ctx, "exit");
+ BasicBlock *BodyBB = BasicBlock::Create(Ctx, "body");
+ emitBlock(PrecondBB, Builder.GetInsertBlock()->getParent());
+ Cnt = Builder.CreateLoad(Builder.getInt32Ty(), CntAddr,
+ /*Volatile=*/false);
+ Value *Cmp = Builder.CreateICmpULT(
+ Cnt, ConstantInt::get(Builder.getInt32Ty(), NumIters));
+ Builder.CreateCondBr(Cmp, BodyBB, ExitBB);
+ emitBlock(BodyBB, Builder.GetInsertBlock()->getParent());
+ }
+
+ // kmpc_barrier.
+ createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
+ omp::Directive::OMPD_unknown,
+ /* ForceSimpleCall */ false,
+ /* CheckCancelFlag */ true, ThreadID);
+ BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
+ BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
+ BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+ // if (lane_id == 0)
+ Value *IsWarpMaster = Builder.CreateIsNull(LaneID, "warp_master");
+ Builder.CreateCondBr(IsWarpMaster, ThenBB, ElseBB);
+ emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
+
+ // Reduce element = LocalReduceList[i]
+ auto *RedListArrayTy =
+ ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+ Type *IndexTy = Builder.getIndexTy(
+ M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+ Value *ElemPtrPtr =
+ Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
+ {ConstantInt::get(IndexTy, 0),
+ ConstantInt::get(IndexTy, En.index())});
+ // elemptr = ((CopyType*)(elemptrptr)) + I
+ Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+ if (NumIters > 1)
+ ElemPtr = Builder.CreateGEP(Builder.getInt32Ty(), ElemPtr, Cnt);
+
+ // Get pointer to location in transfer medium.
+ // MediumPtr = &medium[warp_id]
+ Value *MediumPtr = Builder.CreateInBoundsGEP(
+ ArrayTy, TransferMedium, {Builder.getInt64(0), WarpID});
+ // elem = *elemptr
+ //*MediumPtr = elem
+ Value *Elem = Builder.CreateLoad(CType, ElemPtr);
+ // Store the source element value to the dest element address.
+ Builder.CreateStore(Elem, MediumPtr,
+ /*IsVolatile*/ true);
+ Builder.CreateBr(MergeBB);
+
+ // else
+ emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
+ Builder.CreateBr(MergeBB);
+
+ // endif
+ emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
+ createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
+ omp::Directive::OMPD_unknown,
+ /* ForceSimpleCall */ false,
+ /* CheckCancelFlag */ true, ThreadID);
+
+ // Warp 0 copies reduce element from transfer medium
+ BasicBlock *W0ThenBB = BasicBlock::Create(Ctx, "then");
+ BasicBlock *W0ElseBB = BasicBlock::Create(Ctx, "else");
+ BasicBlock *W0MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+ Value *NumWarpsVal =
+ Builder.CreateLoad(Builder.getInt32Ty(), NumWarpsAddrCast);
+ // Up to 32 threads in warp 0 are active.
+ Value *IsActiveThread =
+ Builder.CreateICmpULT(GPUThreadID, NumWarpsVal, "is_active_thread");
+ Builder.CreateCondBr(IsActiveThread, W0ThenBB, W0ElseBB);
+
+ emitBlock(W0ThenBB, Builder.GetInsertBlock()->getParent());
+
+ // SecMediumPtr = &medium[tid]
+ // SrcMediumVal = *SrcMediumPtr
+ Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
+ ArrayTy, TransferMedium, {Builder.getInt64(0), GPUThreadID});
+ // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
+ Value *TargetElemPtrPtr =
+ Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
+ {ConstantInt::get(IndexTy, 0),
+ ConstantInt::get(IndexTy, En.index())});
+ Value *TargetElemPtrVal =
+ Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtrPtr);
+ Value *TargetElemPtr = TargetElemPtrVal;
+ if (NumIters > 1)
+ TargetElemPtr =
+ Builder.CreateGEP(Builder.getInt32Ty(), TargetElemPtr, Cnt);
+
+ // *TargetElemPtr = SrcMediumVal;
+ Value *SrcMediumValue =
+ Builder.CreateLoad(CType, SrcMediumPtrVal, /*IsVolatile*/ true);
+ Builder.CreateStore(SrcMediumValue, TargetElemPtr);
+ Builder.CreateBr(W0MergeBB);
+
+ emitBlock(W0ElseBB, Builder.GetInsertBlock()->getParent());
+ Builder.CreateBr(W0MergeBB);
+
+ emitBlock(W0MergeBB, Builder.GetInsertBlock()->getParent());
+
+ if (NumIters > 1) {
+ Cnt = Builder.CreateNSWAdd(
+ Cnt, ConstantInt::get(Builder.getInt32Ty(), /*V=*/1));
+ Builder.CreateStore(Cnt, CntAddr, /*Volatile=*/false);
+
+ auto *CurFn = Builder.GetInsertBlock()->getParent();
+ emitBranch(PrecondBB);
+ emitBlock(ExitBB, CurFn);
+ }
+ RealTySize %= TySize;
+ }
+ }
+
+ Builder.CreateRetVoid();
+ Builder.restoreIP(SavedIP);
+
+ return WcFunc;
+}
+
+Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
+ ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
+ AttributeList FuncAttrs) {
+ LLVMContext &Ctx = M.getContext();
+ FunctionType *FuncTy =
+ FunctionType::get(Builder.getVoidTy(),
+ {Builder.getPtrTy(), Builder.getInt16Ty(),
+ Builder.getInt16Ty(), Builder.getInt16Ty()},
+ /* IsVarArg */ false);
+ Function *SarFunc =
+ Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+ "_omp_reduction_shuffle_and_reduce_func", &M);
+ SarFunc->setAttributes(FuncAttrs);
+ SarFunc->addParamAttr(0, Attribute::NoUndef);
+ SarFunc->addParamAttr(1, Attribute::NoUndef);
+ SarFunc->addParamAttr(2, Attribute::NoUndef);
+ SarFunc->addParamAttr(3, Attribute::NoUndef);
+ SarFunc->addParamAttr(1, Attribute::SExt);
+ SarFunc->addParamAttr(2, Attribute::SExt);
+ SarFunc->addParamAttr(3, Attribute::SExt);
+ BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
+ Builder.SetInsertPoint(EntryBB);
+
+ // Thread local Reduce list used to host the values of data to be reduced.
+ Argument *ReduceListArg = SarFunc->getArg(0);
+ // Current lane id; could be logical.
+ Argument *LaneIDArg = SarFunc->getArg(1);
+ // Offset of the remote source lane relative to the current lane.
+ Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
+ // Algorithm version. This is expected to be known at compile time.
+ Argument *AlgoVerArg = SarFunc->getArg(3);
+
+ Type *ReduceListArgType = ReduceListArg->getType();
+ Type *LaneIDArgType = LaneIDArg->getType();
+ Type *LaneIDArgPtrType = LaneIDArg->getType()->getPointerTo();
+ Value *ReduceListAlloca = Builder.CreateAlloca(
+ ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
+ Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
+ LaneIDArg->getName() + ".addr");
+ Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
+ LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
+ Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
+ AlgoVerArg->getName() + ".addr");
+ ArrayType *RedListArrayTy =
+ ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+
+ // Create a local thread-private variable to host the Reduce list
+ // from a remote lane.
+ Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
+ RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");
+
+ Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ReduceListAlloca, ReduceListArgType,
+ ReduceListAlloca->getName() + ".ascast");
+ Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
+ Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ RemoteLaneOffsetAlloca, LaneIDArgPtrType,
+ RemoteLaneOffsetAlloca->getName() + ".ascast");
+ Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
+ Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ RemoteReductionListAlloca, Builder.getPtrTy(),
+ RemoteReductionListAlloca->getName() + ".ascast");
+
+ Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
+ Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
+ Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
+ Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);
+
+ Value *ReduceList = Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
+ Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
+ Value *RemoteLaneOffset =
+ Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
+ Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);
+
+ InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);
+
+ // This loop iterates through the list of reduce elements and copies,
+ // element by element, from a remote lane in the warp to RemoteReduceList,
+ // hosted on the thread's stack.
+ emitReductionListCopy(
+ AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
+ ReduceList, RemoteListAddrCast, {RemoteLaneOffset, nullptr, nullptr});
+
+ // The actions to be performed on the Remote Reduce list is dependent
+ // on the algorithm version.
+ //
+ // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
+ // LaneId % 2 == 0 && Offset > 0):
+ // do the reduction value aggregation
+ //
+ // The thread local variable Reduce list is mutated in place to host the
+ // reduced data, which is the aggregated value produced from local and
+ // remote lanes.
+ //
+ // Note that AlgoVer is expected to be a constant integer known at compile
+ // time.
+ // When AlgoVer==0, the first conjunction evaluates to true, making
+ // the entire predicate true during compile time.
+ // When AlgoVer==1, the second conjunction has only the second part to be
+ // evaluated during runtime. Other conjunctions evaluates to false
+ // during compile time.
+ // When AlgoVer==2, the third conjunction has only the second part to be
+ // evaluated during runtime. Other conjunctions evaluates to false
+ // during compile time.
+ Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
+ Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
+ Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
+ Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
+ Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
+ Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
+ Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
+ Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
+ Value *RemoteOffsetComp =
+ Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
+ Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
+ Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
+ Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);
+
+ BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
+ BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
+ BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+ Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
+ emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
+ // reduce_function(LocalReduceList, RemoteReduceList)
+ Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ReduceList, Builder.getPtrTy());
+ Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ RemoteListAddrCast, Builder.getPtrTy());
+ Builder.CreateCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
+ ->addFnAttr(Attribute::NoUnwind);
+ Builder.CreateBr(MergeBB);
+
+ emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
+ Builder.CreateBr(MergeBB);
+
+ emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
+
+ // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
+ // Reduce list.
+ Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
+ Value *LaneIdGtOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
+ Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGtOffset);
+
+ BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
+ BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
+ BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
+ Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);
+
+ emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());
+ emitReductionListCopy(AllocaIP, CopyAction::ThreadCopy, RedListArrayTy,
+ ReductionInfos, RemoteListAddrCast, ReduceList);
+ Builder.CreateBr(CpyMergeBB);
+
+ emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
+ Builder.CreateBr(CpyMergeBB);
+
+ emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());
+
+ Builder.CreateRetVoid();
+
+ return SarFunc;
+}
+
+Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
+ ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
+ AttributeList FuncAttrs) {
+ OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
+ LLVMContext &Ctx = M.getContext();
+ FunctionType *FuncTy = FunctionType::get(
+ Builder.getVoidTy(),
+ {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
+ /* IsVarArg */ false);
+ Function *LtGCFunc =
+ Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+ "_omp_reduction_list_to_global_copy_func", &M);
+ LtGCFunc->setAttributes(FuncAttrs);
+ LtGCFunc->addParamAttr(0, Attribute::NoUndef);
+ LtGCFunc->addParamAttr(1, Attribute::NoUndef);
+ LtGCFunc->addParamAttr(2, Attribute::NoUndef);
+
+ BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
+ Builder.SetInsertPoint(EntryBlock);
+
+ // Buffer: global reduction buffer.
+ Argument *BufferArg = LtGCFunc->getArg(0);
+ // Idx: index of the buffer.
+ Argument *IdxArg = LtGCFunc->getArg(1);
+ // ReduceList: thread local Reduce list.
+ Argument *ReduceListArg = LtGCFunc->getArg(2);
+
+ Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
+ BufferArg->getName() + ".addr");
+ Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
+ IdxArg->getName() + ".addr");
+ Value *ReduceListArgAlloca = Builder.CreateAlloca(
+ Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
+ Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ BufferArgAlloca, Builder.getPtrTy(),
+ BufferArgAlloca->getName() + ".ascast");
+ Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
+ Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ReduceListArgAlloca, Builder.getPtrTy(),
+ ReduceListArgAlloca->getName() + ".ascast");
+
+ Builder.CreateStore(BufferArg, BufferArgAddrCast);
+ Builder.CreateStore(IdxArg, IdxArgAddrCast);
+ Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
+
+ Value *LocalReduceList =
+ Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+ Value *BufferArgVal =
+ Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
+ Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
+ Type *IndexTy = Builder.getIndexTy(
+ M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+ for (auto En : enumerate(ReductionInfos)) {
+ const ReductionInfo &RI = En.value();
+ auto *RedListArrayTy =
+ ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+ // Reduce element = LocalReduceList[i]
+ Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
+ RedListArrayTy, LocalReduceList,
+ {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+ // elemptr = ((CopyType*)(elemptrptr)) + I
+ Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+
+ // Global = Buffer.VD[Idx];
+ Value *BufferVD =
+ Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
+ Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
+ ReductionsBufferTy, BufferVD, 0, En.index(), "sum");
+
+ switch (RI.EvaluationKind) {
+ case EvalKind::Scalar: {
+ Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
+ Builder.CreateStore(TargetElement, GlobVal);
+ break;
+ }
+ case EvalKind::Complex: {
+ Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
+ RI.ElementType, ElemPtr, 0, 0, ".realp");
+ Value *SrcReal = Builder.CreateLoad(
+ RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
+ Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
+ RI.ElementType, ElemPtr, 0, 1, ".imagp");
+ Value *SrcImg = Builder.CreateLoad(
+ RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
+
+ Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
+ RI.ElementType, GlobVal, 0, 0, ".realp");
+ Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
+ RI.ElementType, GlobVal, 0, 1, ".imagp");
+ Builder.CreateStore(SrcReal, DestRealPtr);
+ Builder.CreateStore(SrcImg, DestImgPtr);
+ break;
+ }
+ case EvalKind::Aggregate: {
+ Value *SizeVal =
+ Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
+ Builder.CreateMemCpy(
+ GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
+ M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
+ break;
+ }
+ }
+ }
+
+ Builder.CreateRetVoid();
+ Builder.restoreIP(OldIP);
+ return LtGCFunc;
+}
+
+ /// Emits the internal helper "_omp_reduction_list_to_global_reduce_func",
+ /// typed void(ptr Buffer, i32 Idx, ptr ReduceList).
+ ///
+ /// The helper builds a local [N x ptr] list whose entry I is the address of
+ /// field I of Buffer[Idx] (laid out as \p ReductionsBufferTy), then calls
+ /// \p ReduceFn(<that list>, ReduceList). Every incoming argument is first
+ /// spilled to a stack slot and reloaded. The caller's insertion point is
+ /// restored before returning.
+ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
+     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
+     Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+   InsertPointTy SavedIP = Builder.saveIP();
+   Type *PtrTy = Builder.getPtrTy();
+   Type *I32Ty = Builder.getInt32Ty();
+   FunctionType *HelperTy = FunctionType::get(
+       Builder.getVoidTy(), {PtrTy, I32Ty, PtrTy}, /* IsVarArg */ false);
+   Function *Helper =
+       Function::Create(HelperTy, GlobalVariable::InternalLinkage,
+                        "_omp_reduction_list_to_global_reduce_func", &M);
+   Helper->setAttributes(FuncAttrs);
+   for (unsigned ArgNo = 0, NumArgs = Helper->arg_size(); ArgNo != NumArgs;
+        ++ArgNo)
+     Helper->addParamAttr(ArgNo, Attribute::NoUndef);
+
+   Builder.SetInsertPoint(BasicBlock::Create(M.getContext(), "entry", Helper));
+
+   // Incoming arguments: global buffer, buffer index, thread-local list.
+   Argument *BufferArg = Helper->getArg(0);
+   Argument *IdxArg = Helper->getArg(1);
+   Argument *ReduceListArg = Helper->getArg(2);
+
+   // Stack slots for the arguments, followed by the local list that will hold
+   // pointers into the global buffer.
+   Value *BufferSlot =
+       Builder.CreateAlloca(PtrTy, nullptr, BufferArg->getName() + ".addr");
+   Value *IdxSlot =
+       Builder.CreateAlloca(I32Ty, nullptr, IdxArg->getName() + ".addr");
+   Value *ReduceListSlot =
+       Builder.CreateAlloca(PtrTy, nullptr, ReduceListArg->getName() + ".addr");
+   ArrayType *PtrListTy = ArrayType::get(PtrTy, ReductionInfos.size());
+   Value *GlobalPtrList =
+       Builder.CreateAlloca(PtrListTy, nullptr, ".omp.reduction.red_list");
+
+   // Cast a stack slot to a generic pointer named "<slot>.ascast".
+   auto ToGenericPtr = [&](Value *Slot) {
+     return Builder.CreatePointerBitCastOrAddrSpaceCast(
+         Slot, PtrTy, Slot->getName() + ".ascast");
+   };
+   Value *BufferSlotCast = ToGenericPtr(BufferSlot);
+   Value *IdxSlotCast = ToGenericPtr(IdxSlot);
+   Value *ReduceListSlotCast = ToGenericPtr(ReduceListSlot);
+   Value *GlobalPtrListCast = ToGenericPtr(GlobalPtrList);
+
+   Builder.CreateStore(BufferArg, BufferSlotCast);
+   Builder.CreateStore(IdxArg, IdxSlotCast);
+   Builder.CreateStore(ReduceListArg, ReduceListSlotCast);
+
+   Value *Buffer = Builder.CreateLoad(PtrTy, BufferSlotCast);
+   Value *Idxs[] = {Builder.CreateLoad(I32Ty, IdxSlotCast)};
+   const auto &DL = M.getDataLayout();
+   Type *IndexTy = Builder.getIndexTy(DL, DL.getDefaultGlobalsAddressSpace());
+   for (unsigned I = 0, E = ReductionInfos.size(); I != E; ++I) {
+     // GlobalPtrList[I] = &Buffer[Idx].<field I>
+     Value *ListEntry = Builder.CreateInBoundsGEP(
+         PtrListTy, GlobalPtrListCast,
+         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, I)});
+     Value *BufferRow =
+         Builder.CreateInBoundsGEP(ReductionsBufferTy, Buffer, Idxs);
+     Value *FieldAddr = Builder.CreateConstInBoundsGEP2_32(
+         ReductionsBufferTy, BufferRow, 0, I, "sum");
+     Builder.CreateStore(FieldAddr, ListEntry);
+   }
+
+   // ReduceFn(GlobalPtrList, ReduceList): the global-buffer list is passed as
+   // the first argument, the thread-local list as the second.
+   Value *ThreadReduceList = Builder.CreateLoad(PtrTy, ReduceListSlotCast);
+   CallInst *ReduceCall =
+       Builder.CreateCall(ReduceFn, {GlobalPtrListCast, ThreadReduceList});
+   ReduceCall->addFnAttr(Attribute::NoUnwind);
+   Builder.CreateRetVoid();
+   Builder.restoreIP(SavedIP);
+   return Helper;
+}
+
+ /// Emits the internal helper "_omp_reduction_global_to_list_copy_func",
+ /// typed void(ptr Buffer, i32 Idx, ptr ReduceList).
+ ///
+ /// For every reduction I, the helper copies field I of Buffer[Idx] (laid out
+ /// as \p ReductionsBufferTy) into the location ReduceList[I] points to.
+ /// Scalars use one load/store, complex values are copied component by
+ /// component, and aggregates use llvm.memcpy. The caller's insertion point
+ /// is restored before returning.
+ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
+     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
+     AttributeList FuncAttrs) {
+   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
+   LLVMContext &Ctx = M.getContext();
+   const DataLayout &DL = M.getDataLayout();
+   FunctionType *FuncTy = FunctionType::get(
+       Builder.getVoidTy(),
+       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
+       /* IsVarArg */ false);
+   // GtLC = global-to-list copy.
+   Function *GtLCFunc =
+       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                        "_omp_reduction_global_to_list_copy_func", &M);
+   GtLCFunc->setAttributes(FuncAttrs);
+   GtLCFunc->addParamAttr(0, Attribute::NoUndef);
+   GtLCFunc->addParamAttr(1, Attribute::NoUndef);
+   GtLCFunc->addParamAttr(2, Attribute::NoUndef);
+
+   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLCFunc);
+   Builder.SetInsertPoint(EntryBlock);
+
+   // Buffer: global reduction buffer.
+   Argument *BufferArg = GtLCFunc->getArg(0);
+   // Idx: index of the buffer.
+   Argument *IdxArg = GtLCFunc->getArg(1);
+   // ReduceList: thread local Reduce list.
+   Argument *ReduceListArg = GtLCFunc->getArg(2);
+
+   // Spill the arguments to stack slots and reload them.
+   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
+                                                 BufferArg->getName() + ".addr");
+   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
+                                              IdxArg->getName() + ".addr");
+   Value *ReduceListArgAlloca = Builder.CreateAlloca(
+       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
+   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+       BufferArgAlloca, Builder.getPtrTy(),
+       BufferArgAlloca->getName() + ".ascast");
+   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
+   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+       ReduceListArgAlloca, Builder.getPtrTy(),
+       ReduceListArgAlloca->getName() + ".ascast");
+   Builder.CreateStore(BufferArg, BufferArgAddrCast);
+   Builder.CreateStore(IdxArg, IdxArgAddrCast);
+   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
+
+   Value *LocalReduceList =
+       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+   Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
+   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
+   Type *IndexTy = Builder.getIndexTy(DL, DL.getDefaultGlobalsAddressSpace());
+   // Type of the thread-local reduce list; it is the same on every iteration.
+   ArrayType *RedListArrayTy =
+       ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+   for (auto En : enumerate(ReductionInfos)) {
+     const OpenMPIRBuilder::ReductionInfo &RI = En.value();
+     // ElemPtr = ReduceList[I], the thread-local destination.
+     Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
+         RedListArrayTy, LocalReduceList,
+         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+     Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+     // GlobValPtr = &Buffer[Idx].<field I>, the source.
+     Value *BufferVD =
+         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
+     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
+         ReductionsBufferTy, BufferVD, 0, En.index(), "sum");
+
+     switch (RI.EvaluationKind) {
+     case EvalKind::Scalar: {
+       Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
+       Builder.CreateStore(TargetElement, ElemPtr);
+       break;
+     }
+     case EvalKind::Complex: {
+       Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
+           RI.ElementType, GlobValPtr, 0, 0, ".realp");
+       Value *SrcReal = Builder.CreateLoad(
+           RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
+       Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
+           RI.ElementType, GlobValPtr, 0, 1, ".imagp");
+       Value *SrcImg = Builder.CreateLoad(
+           RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
+
+       Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
+           RI.ElementType, ElemPtr, 0, 0, ".realp");
+       Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
+           RI.ElementType, ElemPtr, 0, 1, ".imagp");
+       Builder.CreateStore(SrcReal, DestRealPtr);
+       Builder.CreateStore(SrcImg, DestImgPtr);
+       break;
+     }
+     case EvalKind::Aggregate: {
+       // Use the ABI alignment, which is also what CreateLoad uses in the
+       // scalar case. The preferred alignment can be larger than what a field
+       // of the buffer struct is guaranteed to have, because struct fields are
+       // placed at ABI alignment while the data layout may prefer e.g. 8 bytes
+       // for aggregates. Claiming it on llvm.memcpy would be unjustified.
+       Align ElemAlign = DL.getABITypeAlign(RI.ElementType);
+       Value *SizeVal = Builder.getInt64(DL.getTypeStoreSize(RI.ElementType));
+       Builder.CreateMemCpy(ElemPtr, ElemAlign, GlobValPtr, ElemAlign, SizeVal,
+                            /*isVolatile=*/false);
+       break;
+     }
+     }
+   }
+
+   Builder.CreateRetVoid();
+   Builder.restoreIP(OldIP);
+   return GtLCFunc;
+}
+
+ /// Emits the internal helper "_omp_reduction_global_to_list_reduce_func",
+ /// typed void(ptr Buffer, i32 Idx, ptr ReduceList).
+ ///
+ /// The helper fills a local [N x ptr] list with the addresses of the fields
+ /// of Buffer[Idx] (laid out as \p ReductionsBufferTy). It then calls
+ /// \p ReduceFn(ReduceList, <that list>), so the thread-local list is the
+ /// first argument and the global-buffer list the second. The caller's
+ /// insertion point is restored before returning.
+ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
+     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
+     Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+   InsertPointTy PrevIP = Builder.saveIP();
+   Type *PtrTy = Builder.getPtrTy();
+   // Parameter layout: (global buffer, buffer index, thread-local list).
+   constexpr unsigned BufferArgNo = 0, IdxArgNo = 1, ReduceListArgNo = 2;
+   constexpr unsigned NumArgs = 3;
+   Type *ParamTys[NumArgs] = {PtrTy, Builder.getInt32Ty(), PtrTy};
+   FunctionType *HelperTy =
+       FunctionType::get(Builder.getVoidTy(), ParamTys, /* IsVarArg */ false);
+   Function *Helper =
+       Function::Create(HelperTy, GlobalVariable::InternalLinkage,
+                        "_omp_reduction_global_to_list_reduce_func", &M);
+   Helper->setAttributes(FuncAttrs);
+   for (unsigned ArgNo = 0; ArgNo != NumArgs; ++ArgNo)
+     Helper->addParamAttr(ArgNo, Attribute::NoUndef);
+
+   Builder.SetInsertPoint(BasicBlock::Create(M.getContext(), "entry", Helper));
+
+   // One stack slot per argument, then the list of pointers into the buffer.
+   Value *ArgSlots[NumArgs];
+   for (unsigned ArgNo = 0; ArgNo != NumArgs; ++ArgNo)
+     ArgSlots[ArgNo] = Builder.CreateAlloca(
+         ParamTys[ArgNo], nullptr, Helper->getArg(ArgNo)->getName() + ".addr");
+   ArrayType *GlobalListTy = ArrayType::get(PtrTy, ReductionInfos.size());
+   Value *GlobalList =
+       Builder.CreateAlloca(GlobalListTy, nullptr, ".omp.reduction.red_list");
+
+   // Generic-pointer views of the slots and of the list.
+   Value *ArgSlotCasts[NumArgs];
+   for (unsigned ArgNo = 0; ArgNo != NumArgs; ++ArgNo)
+     ArgSlotCasts[ArgNo] = Builder.CreatePointerBitCastOrAddrSpaceCast(
+         ArgSlots[ArgNo], PtrTy, ArgSlots[ArgNo]->getName() + ".ascast");
+   Value *GlobalListCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+       GlobalList, PtrTy, GlobalList->getName() + ".ascast");
+
+   for (unsigned ArgNo = 0; ArgNo != NumArgs; ++ArgNo)
+     Builder.CreateStore(Helper->getArg(ArgNo), ArgSlotCasts[ArgNo]);
+
+   Value *Buffer = Builder.CreateLoad(PtrTy, ArgSlotCasts[BufferArgNo]);
+   Value *Idxs[] = {
+       Builder.CreateLoad(Builder.getInt32Ty(), ArgSlotCasts[IdxArgNo])};
+   const auto &DL = M.getDataLayout();
+   Type *IndexTy = Builder.getIndexTy(DL, DL.getDefaultGlobalsAddressSpace());
+   for (unsigned I = 0, E = ReductionInfos.size(); I != E; ++I) {
+     // GlobalList[I] = &Buffer[Idx].<field I>
+     Value *Entry = Builder.CreateInBoundsGEP(
+         GlobalListTy, GlobalListCast,
+         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, I)});
+     Value *Row = Builder.CreateInBoundsGEP(ReductionsBufferTy, Buffer, Idxs);
+     Value *Field = Builder.CreateConstInBoundsGEP2_32(ReductionsBufferTy, Row,
+                                                       0, I, "sum");
+     Builder.CreateStore(Field, Entry);
+   }
+
+   // ReduceFn(ReduceList, GlobalList)
+   Value *ThreadList = Builder.CreateLoad(PtrTy, ArgSlotCasts[ReduceListArgNo]);
+   CallInst *Call = Builder.CreateCall(ReduceFn, {ThreadList, GlobalListCast});
+   Call->addFnAttr(Attribute::NoUnwind);
+   Builder.CreateRetVoid();
+   Builder.restoreIP(PrevIP);
+   return Helper;
+}
+
+ /// Returns \p Name followed by the platform-specific spelling of the
+ /// {"omp", "reduction", "reduction_func"} suffix.
+ std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
+   std::string FuncName = Name.str();
+   FuncName +=
+       createPlatformSpecificName({"omp", "reduction", "reduction_func"});
+   return FuncName;
+}
+
+ /// Creates the internal function named getReductionFuncName(ReducerName),
+ /// typed void(ptr, ptr).
+ ///
+ /// Both arguments point to [N x ptr] lists with one entry per element of
+ /// \p ReductionInfos. Argument 0 is the LHS list and argument 1 the RHS
+ /// list. For each reduction I:
+ ///  - ReductionGenCBKind::Clang: once all element pointers are computed, the
+ ///    code comes from RI.ReductionGenClang. The placeholder values it reports
+ ///    are then replaced, inside this function only, by the LHS/RHS element
+ ///    pointers.
+ ///  - otherwise: *LHS[I] and *RHS[I] are loaded as RI.ElementType and
+ ///    combined by RI.ReductionGen. The result is stored back through LHS[I].
+ ///    If the callback returns an insertion point without a block, emission
+ ///    stops and the function is returned without a terminator.
+ ///
+ /// On return the builder is positioned inside the new function. Restoring
+ /// the caller's insertion point is the caller's job.
+ Function *OpenMPIRBuilder::createReductionFunction(
+     StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
+     ReductionGenCBKind ReductionGenCBTy, AttributeList FuncAttrs) {
+   auto *FuncTy = FunctionType::get(Builder.getVoidTy(),
+                                    {Builder.getPtrTy(), Builder.getPtrTy()},
+                                    /* IsVarArg */ false);
+   std::string Name = getReductionFuncName(ReducerName);
+   Function *ReductionFunc =
+       Function::Create(FuncTy, GlobalVariable::InternalLinkage, Name, &M);
+   ReductionFunc->setAttributes(FuncAttrs);
+   ReductionFunc->addParamAttr(0, Attribute::NoUndef);
+   ReductionFunc->addParamAttr(1, Attribute::NoUndef);
+   BasicBlock *EntryBB =
+       BasicBlock::Create(M.getContext(), "entry", ReductionFunc);
+   Builder.SetInsertPoint(EntryBB);
+
+   // Spill both arguments to stack slots and reload them before computing the
+   // LHS/RHS element pointers.
+   Argument *Arg0 = ReductionFunc->getArg(0);
+   Argument *Arg1 = ReductionFunc->getArg(1);
+   Type *Arg0Type = Arg0->getType();
+   Type *Arg1Type = Arg1->getType();
+
+   Value *LHSAlloca =
+       Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
+   Value *RHSAlloca =
+       Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
+   Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+       LHSAlloca, Arg0Type, LHSAlloca->getName() + ".ascast");
+   Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+       RHSAlloca, Arg1Type, RHSAlloca->getName() + ".ascast");
+   Builder.CreateStore(Arg0, LHSAddrCast);
+   Builder.CreateStore(Arg1, RHSAddrCast);
+   Value *LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
+   Value *RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
+
+   Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+   Type *IndexTy = Builder.getIndexTy(
+       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+   SmallVector<Value *> LHSPtrs, RHSPtrs;
+   for (auto En : enumerate(ReductionInfos)) {
+     const ReductionInfo &RI = En.value();
+     Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
+         RedArrayTy, RHSArrayPtr,
+         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
+     Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+         RHSI8Ptr, RI.PrivateVariable->getType(),
+         RHSI8Ptr->getName() + ".ascast");
+
+     Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
+         RedArrayTy, LHSArrayPtr,
+         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
+     Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+         LHSI8Ptr, RI.Variable->getType(), LHSI8Ptr->getName() + ".ascast");
+
+     if (ReductionGenCBTy == ReductionGenCBKind::Clang) {
+       LHSPtrs.emplace_back(LHSPtr);
+       RHSPtrs.emplace_back(RHSPtr);
+     } else {
+       Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
+       Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
+       Value *Reduced = nullptr;
+       // Continue from the insertion point the callback hands back, so the
+       // store lands after whatever code (and blocks) it emitted. A result
+       // without a block is treated as a request to stop emitting.
+       Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
+       if (!Builder.GetInsertBlock())
+         return ReductionFunc;
+       Builder.CreateStore(Reduced, LHSPtr);
+     }
+   }
+
+   if (ReductionGenCBTy == ReductionGenCBKind::Clang) {
+     // Only uses located inside ReductionFunc are rewritten.
+     auto IsInReductionFunc = [ReductionFunc](const Use &U) {
+       return cast<Instruction>(U.getUser())->getParent()->getParent() ==
+              ReductionFunc;
+     };
+     for (auto En : enumerate(ReductionInfos)) {
+       unsigned Index = En.index();
+       const ReductionInfo &RI = En.value();
+       Value *LHSFixupPtr = nullptr;
+       Value *RHSFixupPtr = nullptr;
+       Builder.restoreIP(RI.ReductionGenClang(
+           Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));
+       assert(LHSFixupPtr && RHSFixupPtr &&
+              "ReductionGenClang must report the LHS/RHS placeholder values");
+
+       // Fix the callback-generated code to use the correct values for the
+       // LHS and RHS.
+       LHSFixupPtr->replaceUsesWithIf(LHSPtrs[Index], IsInReductionFunc);
+       RHSFixupPtr->replaceUsesWithIf(RHSPtrs[Index], IsInReductionFunc);
+     }
+   }
+
+   Builder.CreateRetVoid();
+   return ReductionFunc;
+}
+
+ /// Validates \p ReductionInfos in assertion-enabled builds; compiles to
+ /// nothing otherwise. Each entry needs a variable, a private variable and at
+ /// least one reduction-generator callback, and its variable must be
+ /// pointer-typed. When \p IsGPU is false, the variable and its private copy
+ /// must also have the same type.
+ static void
+ checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
+                     bool IsGPU) {
+   (void)IsGPU; // Only referenced from assertions.
+   for (const OpenMPIRBuilder::ReductionInfo &Info : ReductionInfos) {
+     (void)Info; // Only referenced from assertions.
+     assert(Info.Variable && "expected non-null variable");
+     assert(Info.PrivateVariable && "expected non-null private variable");
+     assert((Info.ReductionGen || Info.ReductionGenClang) &&
+            "expected non-null reduction generator callback");
+     assert((IsGPU ||
+             Info.Variable->getType() == Info.PrivateVariable->getType()) &&
+            "expected variables and their private equivalents to have the same "
+            "type");
+     assert(Info.Variable->getType()->isPointerTy() &&
+            "expected variables to be pointers");
+   }
+}
+
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU(
+ const LocationDescription &Loc, InsertPointTy AllocaIP,
+ InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
+ bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
+ ReductionGenCBKind ReductionGenCBTy, std::optional<omp::GV> GridValue,
+ unsigned ReductionBufNum, Value *SrcLocInfo) {
+ if (!updateToLocation(Loc))
+ return InsertPointTy();
+ Builder.restoreIP(CodeGenIP);
+ checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
+ LLVMContext &Ctx = M.getContext();
+
+ // Source location for the ident struct
+ if (!SrcLocInfo) {
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+ SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+ }
+
+ if (ReductionInfos.size() == 0)
+ return Builder.saveIP();
+
+ Function *CurFunc = Builder.GetInsertBlock()->getParent();
+ AttributeList FuncAttrs;
+ AttrBuilder AttrBldr(Ctx);
+ for (auto Attr : CurFunc->getAttributes().getFnAttrs())
+ AttrBldr.addAttribute(Attr);
+ AttrBldr.removeAttribute(Attribute::OptimizeNone);
+ FuncAttrs = FuncAttrs.addFnAttributes(Ctx, AttrBldr);
+
+ Function *ReductionFunc = nullptr;
+ if (GLOBAL_ReductionFunc) {
+ ReductionFunc = GLOBAL_ReductionFunc;
+ } else {
+ CodeGenIP = Builder.saveIP();
+ ReductionFunc = createReductionFunction(
+ Builder.GetInsertBlock()->getParent()->getName(), ReductionInfos,
+ ReductionGenCBTy, FuncAttrs);
+ Builder.restoreIP(CodeGenIP);
+ }
+
+ // Set the grid value in the config needed for lowering later on
+ if (GridValue.has_value())
+ Config.setGridValue(GridValue.value());
+ else
+ Config.setGridValue(getGridValue(T, ReductionFunc));
+
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+ Value *RTLoc =
+ getOrCreateIdent(SrcLocStr, SrcLocStrSize, omp::IdentFlag(0), 0);
+
+ // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
+ // RedList, shuffle_reduce_func, interwarp_copy_func);
+ // or
+ // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
+ Value *Res;
+
+ // 1. Build a list of reduction variables.
+ // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
+ auto Size = ReductionInfos.size();
+ Type *PtrTy = PointerType::getUnqual(Ctx);
+ Type *RedArrayTy = ArrayType::get(PtrTy, Size);
+ CodeGenIP = Builder.saveIP();
+ Builder.restoreIP(AllocaIP);
+ Value *ReductionListAlloca =
+ Builder.CreateAlloca(RedArrayTy, nullptr, ".omp.reduction.red_list");
+ Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ReductionListAlloca, PtrTy, ReductionListAlloca->getName() + ".ascast");
+ Builder.restoreIP(CodeGenIP);
+ Type *IndexTy = Builder.getIndexTy(
+ M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+ for (auto En : enumerate(ReductionInfos)) {
+ const ReductionInfo &RI = En.value();
+ Value *ElemPtr = Builder.CreateInBoundsGEP(
+ RedArrayTy, ReductionList,
+ {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+ Value *CastElem =
+ Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
+ Builder.CreateStore(CastElem, ElemPtr);
+ }
+ CodeGenIP = Builder.saveIP();
+ Function *SarFunc =
+ emitShuffleAndReduceFunction(ReductionInfos, ReductionFunc, FuncAttrs);
+ Function *WcFunc = emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs);
+ Builder.restoreIP(CodeGenIP);
+
+ Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(ReductionList, PtrTy);
+
+ unsigned MaxDataSize = 0;
+ SmallVector<Type *> ReductionTypeArgs;
+ for (auto En : enumerate(ReductionInfos)) {
+ auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
+ if (Size > MaxDataSize)
+ MaxDataSize = Size;
+ ReductionTypeArgs.emplace_back(En.value().ElementType);
+ }
+ Value *ReductionDataSize =
+ Builder.getInt64(MaxDataSize * ReductionInfos.size());
+ if (!IsTeamsReduction) {
+ Value *SarFuncCast =
+ Builder.CreatePointerBitCastOrAddrSpaceCast(SarFunc, PtrTy);
+ Value *WcFuncCast =
+ Builder.CreatePointerBitCastOrAddrSpaceCast(WcFunc, PtrTy);
+ Value *Args[] = {RTLoc, ReductionDataSize, RL, SarFuncCast, WcFuncCast};
+ Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
+ RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
+ Res = Builder.CreateCall(Pv2Ptr, Args);
+ } else {
+ CodeGenIP = Builder.saveIP();
+ StructType *ReductionsBufferTy = StructType::create(
+ Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
+ Function *RedFixedBuferFn = getOrCreateRuntimeFunctionPtr(
+ RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
+ Function *LtGCFunc = emitListToGlobalCopyFunction(
+ ReductionInfos, ReductionsBufferTy, FuncAttrs);
+ Function *LtGRFunc = emitListToGlobalReduceFunction(
+ ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
+ Function *GtLCFunc = emitGlobalToListCopyFunction(
+ ReductionInfos, ReductionsBufferTy, FuncAttrs);
+ Function *GtLRFunc = emitGlobalToListReduceFunction(
+ ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
+ Builder.restoreIP(CodeGenIP);
+
+ Value *KernelTeamsReductionPtr = Builder.CreateCall(
+ RedFixedBuferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
+
+ Value *Args3[] = {RTLoc,
+ KernelTeamsReductionPtr,
+ Builder.getInt32(ReductionBufNum),
+ ReductionDataSize,
+ RL,
+ SarFunc,
+ WcFunc,
+ LtGCFunc,
+ LtGRFunc,
+ GtLCFunc,
+ GtLRFunc};
+
+ Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
+ RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
+ Res = Builder.CreateCall(TeamsReduceFn, Args3);
+ }
+
+ // 5. Build if (res == 1)
+ BasicBlock *ExitBB = BasicBlock::Create(Ctx, ".omp.reduction.done");
+ BasicBlock *ThenBB = BasicBlock::Create(Ctx, ".omp.reduction.then");
+ Value *Cond = Builder.CreateICmpEQ(Res, Builder.getInt32(1));
+ Builder.CreateCondBr(Cond, ThenBB, ExitBB);
+
+ // 6. Build then branch: where we have reduced values in the master
+ // thread in each team.
+ // __kmpc_end_reduce{_nowait}(<gtid>);
+ // break;
+ emitBlock(ThenBB, CurFunc);
+
+ // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
+ for (auto En : enumerate(ReductionInfos)) {
+ const ReductionInfo &RI = En.value();
+ Value *LHS = RI.Variable;
+ Value *RHS =
+ Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
+
+ if (ReductionGenCBTy == ReductionGenCBKind::Clang) {
+ Value *LHSPtr, *RHSPtr;
+ Builder.restoreIP(RI.ReductionGenClang(Builder.saveIP(), En.index(),
+ &LHSPtr, &RHSPtr, CurFunc));
+
+ // Fix the CallBack code genereated to use the correct Values for the LHS
+ // and RHS
+ LHSPtr->replaceUsesWithIf(LHS, [ReductionFunc](const Use &U) {
+ return cast<Instruction>(U.getUser())->getParent()->getParent() ==
+ ReductionFunc;
+ });
+ RHSPtr->replaceUsesWithIf(RHS, [ReductionFunc](const Use &U) {
+ return cast<Instruction>(U.getUser())->getParent()->getParent() ==
+ ReductionFunc;
+ });
+ } else {
+ // LHS = Builder.CreateLoad(LHS);
+ // LHS = Builder.CreateLoad(LHS);
+ // Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS));
----------------
TIFitis wrote:
Code should be added here once we add MLIR/Flang codegen support. This path is unreachable for now, so I've replaced the commented-out code with an assertion failure.
https://github.com/llvm/llvm-project/pull/80343
More information about the llvm-commits
mailing list