[Mlir-commits] [clang] [llvm] [mlir] [OpenMP] Migrate GPU Reductions CodeGen from Clang to OMPIRBuilder (PR #80343)
Akash Banerjee
llvmlistbot at llvm.org
Wed Feb 7 06:29:36 PST 2024
================
@@ -2042,35 +2057,1378 @@ OpenMPIRBuilder::createSection(const LocationDescription &Loc,
/*IsCancellable*/ true);
}
-/// Create a function with a unique name and a "void (i8*, i8*)" signature in
-/// the given module and return it.
-Function *getFreshReductionFunc(Module &M) {
- Type *VoidTy = Type::getVoidTy(M.getContext());
- Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
- auto *FuncTy =
- FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
- return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
- M.getDataLayout().getDefaultGlobalsAddressSpace(),
- ".omp.reduction.func", &M);
+/// Returns an insertion point located immediately after \p I in its parent
+/// block. If \p I is the last instruction, the point is that block's end().
+static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
+  BasicBlock::iterator Next = I->getIterator();
+  ++Next;
+  return OpenMPIRBuilder::InsertPointTy(I->getParent(), Next);
+}
-OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
- const LocationDescription &Loc, InsertPointTy AllocaIP,
- ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
- for (const ReductionInfo &RI : ReductionInfos) {
+/// Emits an appending-linkage global named \p Name, placed in section
+/// "llvm.metadata". Its initializer is an array of generic pointers to the
+/// values still alive in \p List. Presumably \p Name is "llvm.used" or
+/// "llvm.compiler.used" (confirm at call sites).
+/// Nothing is emitted if \p List holds no live values.
+void OpenMPIRBuilder::emitUsed(StringRef Name,
+                               std::vector<WeakTrackingVH> &List) {
+  if (List.empty())
+    return;
+
+  // Convert List to what ConstantArray needs. A WeakTrackingVH becomes null
+  // when the value it tracks is deleted. Such entries are skipped instead of
+  // being handed to cast<>, which does not accept a null value.
+  SmallVector<Constant *, 8> UsedArray;
+  UsedArray.reserve(List.size());
+  for (const WeakTrackingVH &Handle : List) {
+    if (!Handle)
+      continue;
+    UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*Handle), Builder.getPtrTy()));
+  }
+
+  // Every handle may have been dropped above.
+  if (UsedArray.empty())
+    return;
+
+  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
+  auto *GV = new GlobalVariable(M, ATy, /*isConstant=*/false,
+                                GlobalValue::AppendingLinkage,
+                                ConstantArray::get(ATy, UsedArray), Name);
+  GV->setSection("llvm.metadata");
+}
+
+/// Emits a call to the __kmpc_get_hardware_thread_id_in_block runtime entry
+/// point and returns its result.
+Value *OpenMPIRBuilder::getGPUThreadID() {
+  FunctionCallee ThreadIDFn = getOrCreateRuntimeFunction(
+      M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
+  return Builder.CreateCall(ThreadIDFn, {});
+}
+
+/// Emits a call to the __kmpc_get_warp_size runtime entry point and returns
+/// its result.
+Value *OpenMPIRBuilder::getGPUWarpSize() {
+  FunctionCallee WarpSizeFn =
+      getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size);
+  return Builder.CreateCall(WarpSizeFn, {});
+}
+
+/// Emits `thread_id >> log2(GV_Warp_Size)` as an arithmetic shift, named
+/// "nvptx_warp_id".
+Value *OpenMPIRBuilder::getNVPTXWarpID() {
+  Value *ThreadID = getGPUThreadID();
+  const unsigned WarpSizeLog2 = Log2_32(Config.getGridValue().GV_Warp_Size);
+  return Builder.CreateAShr(ThreadID, WarpSizeLog2, "nvptx_warp_id");
+}
+
+/// Emits `thread_id & ((1 << log2(GV_Warp_Size)) - 1)`, named "nvptx_lane_id".
+Value *OpenMPIRBuilder::getNVPTXLaneID() {
+  unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
+  assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
+  // Build the low-bit mask by shifting 1 up rather than ~0u down.
+  // `~0u >> (32 - LaneIDBits)` shifts by the full bit width (undefined
+  // behaviour) when LaneIDBits == 0. This form is well-defined for every
+  // value the assert admits and gives the same mask for 1..31.
+  unsigned LaneIDMask = (1u << LaneIDBits) - 1u;
+  return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
+                           "nvptx_lane_id");
+}
+
+/// Converts \p From to \p ToType.
+///  - Identical types are returned unchanged.
+///  - Equal store sizes use a bitcast.
+///  - Integer-to-integer conversions use a signed int cast.
+///  - Any other conversion goes through memory: a \p ToType temporary is
+///    allocated at \p AllocaIP, \p From is stored into it, and the temporary
+///    is reloaded as \p ToType.
+Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
+                                        Type *ToType) {
+  const auto &DL = M.getDataLayout();
+  Type *FromType = From->getType();
+  uint64_t FromSize = DL.getTypeStoreSize(FromType);
+  uint64_t ToSize = DL.getTypeStoreSize(ToType);
+  assert(FromSize > 0 && "From size must be greater than zero");
+  assert(ToSize > 0 && "To size must be greater than zero");
+
+  if (FromType == ToType)
+    return From;
+  if (FromSize == ToSize)
+    return Builder.CreateBitCast(From, ToType);
+  if (FromType->isIntegerTy() && ToType->isIntegerTy())
+    return Builder.CreateIntCast(From, ToType, /*isSigned=*/true);
+
+  // The temporary lives in the alloca region; code emission then resumes
+  // where it left off.
+  InsertPointTy CodeGenIP = Builder.saveIP();
+  Builder.restoreIP(AllocaIP);
+  Value *Slot = Builder.CreateAlloca(ToType);
+  Builder.restoreIP(CodeGenIP);
+
+  Value *SlotAsFrom = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      Slot, FromType->getPointerTo());
+  Builder.CreateStore(From, SlotAsFrom);
+  return Builder.CreateLoad(ToType, Slot);
+}
+
+/// Shuffles \p Element, of type \p ElementType and at most 8 bytes, by
+/// \p Offset lanes. It calls __kmpc_shuffle_int32 or __kmpc_shuffle_int64 with
+/// the GPU warp size truncated to i16. The value is widened to a 32- or 64-bit
+/// integer for the call, and the result is converted back to \p ElementType.
+Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
+                                                     Value *Element,
+                                                     Type *ElementType,
+                                                     Value *Offset) {
+  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
+  assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
+
+  // Cast all types to 32- or 64-bit values before calling shuffle routines.
+  Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
+  Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
+  Value *WarpSize = Builder.CreateIntCast(
+      getGPUWarpSize(), Builder.getInt16Ty(), /*isSigned=*/true);
+  Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
+      Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
+                : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
+  Value *ShuffleCall =
+      Builder.CreateCall(ShuffleFunc, {ElemCast, Offset, WarpSize});
+  // Undo the widening so the caller receives a value of the type it asked
+  // for. Returning CastTy would make a caller that stores the result write 4
+  // or 8 bytes for a 1- or 2-byte element.
+  return castValueToType(AllocaIP, ShuffleCall, ElementType);
+}
+
+/// Shuffles the \p ElemType element at \p SrcAddr by \p Offset lanes, using
+/// createRuntimeShuffleFunction, and stores the result to \p DstAddr.
+/// The element is moved in integer chunks of 8, 4, 2 and 1 bytes. When more
+/// than one chunk of a width is needed, a loop is emitted that shuffles one
+/// chunk per iteration. \p ReductionArrayTy is currently unused.
+void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
+                                      Value *DstAddr, Type *ElemType,
+                                      Value *Offset, Type *ReductionArrayTy) {
+  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElemType);
+  // Create the loop over the big sized data.
+  // ptr = (void*)Elem;
+  // ptrEnd = (void*) Elem + 1;
+  // Step = 8;
+  // while (ptr + Step < ptrEnd)
+  //   shuffle((int64_t)*ptr);
+  // Step = 4;
+  // while (ptr + Step < ptrEnd)
+  //   shuffle((int32_t)*ptr);
+  // ...
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  Value *ElemPtr = DstAddr;
+  Value *Ptr = SrcAddr;
+  for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
+    if (Size < IntSize)
+      continue;
+    Type *IntType = Builder.getIntNTy(IntSize * 8);
+    Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        Ptr, IntType->getPointerTo(), Ptr->getName() + ".ascast");
+    ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        ElemPtr, IntType->getPointerTo(), ElemPtr->getName() + ".ascast");
+
+    Function *CurFunc = Builder.GetInsertBlock()->getParent();
+    if ((Size / IntSize) > 1) {
+      // The end pointer (one element past SrcAddr) bounds the loop. It is only
+      // needed on this path, so it is not emitted for single-chunk copies.
+      Value *SrcAddrGEP =
+          Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
+      Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          SrcAddrGEP, Builder.getPtrTy());
+      // Chunk pointers advance by IntSize bytes per iteration. After the first
+      // iteration they are only known to be aligned to min(element alignment,
+      // IntSize), so the element's own alignment must not be claimed on every
+      // chunk access.
+      Align ChunkAlign = commonAlignment(
+          M.getDataLayout().getPrefTypeAlign(ElemType), IntSize);
+      BasicBlock *PreCondBB =
+          BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
+      BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
+      BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
+      BasicBlock *CurrentBB = Builder.GetInsertBlock();
+      emitBlock(PreCondBB, CurFunc);
+      PHINode *PhiSrc =
+          Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
+      PhiSrc->addIncoming(Ptr, CurrentBB);
+      PHINode *PhiDest =
+          Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
+      PhiDest->addIncoming(ElemPtr, CurrentBB);
+      Ptr = PhiSrc;
+      ElemPtr = PhiDest;
+      Value *PtrDiff = Builder.CreatePtrDiff(
+          Builder.getInt8Ty(), PtrEnd,
+          Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr, Builder.getPtrTy()));
+      Builder.CreateCondBr(
+          Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
+          ExitBB);
+      emitBlock(ThenBB, CurFunc);
+      Value *Res = createRuntimeShuffleFunction(
+          AllocaIP, Builder.CreateAlignedLoad(IntType, Ptr, ChunkAlign),
+          IntType, Offset);
+      Builder.CreateAlignedStore(Res, ElemPtr, ChunkAlign);
+      Value *LocalPtr =
+          Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
+      Value *LocalElemPtr =
+          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
+      PhiSrc->addIncoming(LocalPtr, ThenBB);
+      PhiDest->addIncoming(LocalElemPtr, ThenBB);
+      emitBranch(PreCondBB);
+      emitBlock(ExitBB, CurFunc);
+    } else {
+      Value *Res = createRuntimeShuffleFunction(
+          AllocaIP, Builder.CreateLoad(IntType, Ptr), IntType, Offset);
+      // The runtime shuffle works on 32- or 64-bit integers, so the result may
+      // be wider than this chunk. Narrow it to the chunk width so the store
+      // writes exactly IntSize bytes. Narrowing only for integer ElemType
+      // would let e.g. a 2-byte floating-point element store 4 bytes.
+      if (Res->getType()->getScalarSizeInBits() >
+          IntType->getScalarSizeInBits())
+        Res = Builder.CreateTrunc(Res, IntType);
+      Builder.CreateStore(Res, ElemPtr);
+      Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
+      ElemPtr =
+          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
+    }
+    Size = Size % IntSize;
+  }
+}
+
+/// Copies each element referenced by the reduce list at \p SrcBase into the
+/// reduce list at \p DestBase. Both lists are arrays of element pointers
+/// indexed through \p ReductionArrayTy.
+///  - RemoteLaneToThread: a fresh temporary is allocated at \p AllocaIP for
+///    each element. The source element is shuffled into it by
+///    CopyOptions.RemoteLaneOffset lanes, and DestBase[i] is updated to point
+///    at the temporary.
+///  - ThreadCopy: the source element is copied into the storage that
+///    DestBase[i] already points to.
+void OpenMPIRBuilder::emitReductionListCopy(
+    InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
+    ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
+    CopyOptionsTy CopyOptions) {
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;
+
+  // Iterates, element-by-element, through the source Reduce list and
+  // make a copy.
+  for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    Value *SrcElementAddr = nullptr;
+    Value *DestElementAddr = nullptr;
+    Value *DestElementPtrAddr = nullptr;
+    // Should we shuffle in an element from a remote lane?
+    bool ShuffleInElement = false;
+    // Set to true to update the pointer in the dest Reduce list to a
+    // newly created element.
+    bool UpdateDestListPtr = false;
+
+    // Step 1.1: Get the address for the src element in the Reduce list.
+    Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
+        ReductionArrayTy, SrcBase,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    SrcElementAddr = Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrAddr);
+
+    // Step 1.2: Create a temporary to store the element in the destination
+    // Reduce list.
+    DestElementPtrAddr = Builder.CreateInBoundsGEP(
+        ReductionArrayTy, DestBase,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    switch (Action) {
+    case CopyAction::RemoteLaneToThread: {
+      InsertPointTy CurIP = Builder.saveIP();
+      Builder.restoreIP(AllocaIP);
+      AllocaInst *DestAlloca = Builder.CreateAlloca(RI.ElementType, nullptr,
+                                                    ".omp.reduction.element");
+      DestAlloca->setAlignment(
+          M.getDataLayout().getPrefTypeAlign(RI.ElementType));
+      DestElementAddr = DestAlloca;
+      DestElementAddr =
+          Builder.CreateAddrSpaceCast(DestElementAddr, Builder.getPtrTy(),
+                                      DestElementAddr->getName() + ".ascast");
+      Builder.restoreIP(CurIP);
+      ShuffleInElement = true;
+      UpdateDestListPtr = true;
+      break;
+    }
+    case CopyAction::ThreadCopy: {
+      DestElementAddr =
+          Builder.CreateLoad(Builder.getPtrTy(), DestElementPtrAddr);
+      break;
+    }
+    }
+
+    // Now that all active lanes have read the element in the
+    // Reduce list, shuffle over the value from the remote lane.
+    if (ShuffleInElement) {
+      shuffleAndStore(AllocaIP, SrcElementAddr, DestElementAddr, RI.ElementType,
+                      RemoteLaneOffset, ReductionArrayTy);
+    } else {
+      switch (RI.EvaluationKind) {
+      case EvaluationKindTy::Scalar: {
+        Value *Elem = Builder.CreateLoad(RI.ElementType, SrcElementAddr);
+        // Store the source element value to the dest element address.
+        Builder.CreateStore(Elem, DestElementAddr);
+        break;
+      }
+      case EvaluationKindTy::Complex:
+        // Complex elements are copied bytewise, like aggregates. This needs
+        // no assumption about their in-memory layout, and it guarantees the
+        // destination element is actually written.
+      case EvaluationKindTy::Aggregate: {
+        Value *SizeVal = Builder.getInt64(
+            M.getDataLayout().getTypeStoreSize(RI.ElementType));
+        Builder.CreateMemCpy(
+            DestElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
+            SrcElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
+            SizeVal, false);
+        break;
+      }
+      }
+    }
+
+    // Step 3.1: Modify reference in dest Reduce list as needed.
+    // Modifying the reference in Reduce list to point to the newly
+    // created element. The element is live in the current function
+    // scope and that of functions it invokes (i.e., reduce_function).
+    // RemoteReduceData[i] = (void*)&RemoteElem
+    if (UpdateDestListPtr) {
+      Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          DestElementAddr, Builder.getPtrTy(),
+          DestElementAddr->getName() + ".ascast");
+      Builder.CreateStore(CastDestAddr, DestElementPtrAddr);
+    }
+  }
+}
+
+/// Emits "_omp_reduction_inter_warp_copy_func", an internal function of type
+/// `void(ptr ReduceList, i32 NumWarps)`.
+/// Each reduction element is processed in chunks of 4, then 2, then 1 bytes.
+/// For each chunk:
+///  - lane 0 of every warp stores the chunk into slot [warp_id] of a shared
+///    transfer medium;
+///  - after a barrier, each thread whose thread id is below NumWarps loads
+///    slot [thread_id] back into the same chunk of its own reduce list
+///    element.
+/// The builder's insertion point is saved on entry and restored before
+/// returning.
+Function *OpenMPIRBuilder::emitInterWarpCopyFunction(
+ const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
+ AttributeList FuncAttrs) {
+ InsertPointTy SavedIP = Builder.saveIP();
+ LLVMContext &Ctx = M.getContext();
+ FunctionType *FuncTy = FunctionType::get(
+ Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()},
+ /* IsVarArg */ false);
+ Function *WcFunc =
+ Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+ "_omp_reduction_inter_warp_copy_func", &M);
+ WcFunc->setAttributes(FuncAttrs);
+ WcFunc->addParamAttr(0, Attribute::NoUndef);
+ WcFunc->addParamAttr(1, Attribute::NoUndef);
+ BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc);
+ Builder.SetInsertPoint(EntryBB);
+
+ // ReduceList: thread local Reduce list.
+ // At the stage of the computation when this function is called, partially
+ // aggregated values reside in the first lane of every active warp.
+ Argument *ReduceListArg = WcFunc->getArg(0);
+ // NumWarps: number of warps active in the parallel region. This could
+ // be smaller than 32 (max warps in a CTA) for partial block reduction.
+ Argument *NumWarpsArg = WcFunc->getArg(1);
+
+ // This array is used as a medium to transfer, one reduce element at a time,
+ // the data from the first lane of every warp to lanes in the first warp
+ // in order to perform the final step of a reduction in a parallel region
+ // (reduction across warps). The array is placed in NVPTX __shared__ memory
+ // for reduced latency, as well as to have a distinct copy for concurrently
+ // executing target regions. If the module does not already contain a
+ // global of this name, it is created with weak (WeakAny) linkage in
+ // address space 3, so that it can be shared across compilation units.
+ // It holds one i32 slot per lane of a warp (GV_Warp_Size entries).
+ StringRef TransferMediumName =
+ "__openmp_nvptx_data_transfer_temporary_storage";
+ GlobalVariable *TransferMedium = M.getGlobalVariable(TransferMediumName);
+ unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
+ ArrayType *ArrayTy = ArrayType::get(Builder.getInt32Ty(), WarpSize);
+ if (!TransferMedium) {
+ TransferMedium = new GlobalVariable(
+ M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
+ UndefValue::get(ArrayTy), TransferMediumName,
+ /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
+ /*AddressSpace=*/3);
+ }
+
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+
+ // Get the CUDA thread id of the current OpenMP thread on the GPU.
+ Value *GPUThreadID = getGPUThreadID();
+ // nvptx_lane_id = nvptx_id % warpsize
+ Value *LaneID = getNVPTXLaneID();
+ // nvptx_warp_id = nvptx_id / warpsize
+ Value *WarpID = getNVPTXWarpID();
+
+ // From here through the argument stores below, code is emitted at the
+ // start of the entry block, ahead of the thread-id queries emitted above.
+ InsertPointTy AllocaIP =
+ InsertPointTy(Builder.GetInsertBlock(),
+ Builder.GetInsertBlock()->getFirstInsertionPt());
+ Type *Arg0Type = ReduceListArg->getType();
+ Type *Arg1Type = NumWarpsArg->getType();
+ Builder.restoreIP(AllocaIP);
+ AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
+ Arg0Type, nullptr, ReduceListArg->getName() + ".addr");
+ AllocaInst *NumWarpsAlloca =
+ Builder.CreateAlloca(Arg1Type, nullptr, NumWarpsArg->getName() + ".addr");
+ Value *ThreadID =
+ getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize));
+ Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ReduceListAlloca, Arg0Type, ReduceListAlloca->getName() + ".ascast");
+ Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ NumWarpsAlloca, Arg1Type->getPointerTo(),
+ NumWarpsAlloca->getName() + ".ascast");
+ Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
+ Builder.CreateStore(NumWarpsArg, NumWarpsAddrCast);
+ // Later allocas (the ".cnt.addr" loop counters) go right after
+ // NumWarpsAlloca. Code emission continues after the last instruction
+ // currently in the block.
+ AllocaIP = getInsertPointAfterInstr(NumWarpsAlloca);
+ InsertPointTy CodeGenIP =
+ getInsertPointAfterInstr(&Builder.GetInsertBlock()->back());
+ Builder.restoreIP(CodeGenIP);
+
+ Value *ReduceList =
+ Builder.CreateLoad(Builder.getPtrTy(), ReduceListAddrCast);
+
+ for (auto En : enumerate(ReductionInfos)) {
+ //
+ // Warp master copies reduce element to transfer medium in __shared__
+ // memory.
+ //
+ const ReductionInfo &RI = En.value();
+ unsigned RealTySize = M.getDataLayout().getTypeAllocSize(RI.ElementType);
+ // Copy the element in chunks of TySize = 4, 2, then 1 bytes. Only the
+ // 4-byte pass can run more than one iteration, because
+ // `RealTySize %= TySize` leaves fewer than 4 bytes afterwards. That is why
+ // the loop-index GEPs below step in i32 units.
+ // NOTE(review): each pass reloads the element pointer from the reduce list
+ // without skipping the bytes copied by an earlier, wider pass. For sizes
+ // that are not a multiple of 4, the trailing 2/1-byte chunk therefore
+ // appears to re-copy the element's leading bytes rather than its tail.
+ // Confirm.
+ for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
+ Type *CType = Builder.getIntNTy(TySize * 8);
+
+ unsigned NumIters = RealTySize / TySize;
+ if (NumIters == 0)
+ continue;
+ Value *Cnt = nullptr;
+ Value *CntAddr = nullptr;
+ BasicBlock *PrecondBB = nullptr;
+ BasicBlock *ExitBB = nullptr;
+ if (NumIters > 1) {
+ // Counted loop over the chunks. "precond" tests the counter in
+ // .cnt.addr against NumIters, "body" handles one chunk, and the counter
+ // is incremented after the copy below.
+ CodeGenIP = Builder.saveIP();
+ Builder.restoreIP(AllocaIP);
+ CntAddr =
+ Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, ".cnt.addr");
+
+ CntAddr = Builder.CreateAddrSpaceCast(CntAddr, Builder.getPtrTy(),
+ CntAddr->getName() + ".ascast");
+ Builder.restoreIP(CodeGenIP);
+ Builder.CreateStore(Constant::getNullValue(Builder.getInt32Ty()),
+ CntAddr,
+ /*Volatile=*/false);
+ PrecondBB = BasicBlock::Create(Ctx, "precond");
+ ExitBB = BasicBlock::Create(Ctx, "exit");
+ BasicBlock *BodyBB = BasicBlock::Create(Ctx, "body");
+ emitBlock(PrecondBB, Builder.GetInsertBlock()->getParent());
+ Cnt = Builder.CreateLoad(Builder.getInt32Ty(), CntAddr,
+ /*Volatile=*/false);
+ Value *Cmp = Builder.CreateICmpULT(
+ Cnt, ConstantInt::get(Builder.getInt32Ty(), NumIters));
+ Builder.CreateCondBr(Cmp, BodyBB, ExitBB);
+ emitBlock(BodyBB, Builder.GetInsertBlock()->getParent());
+ }
+
+ // kmpc_barrier.
+ createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
+ omp::Directive::OMPD_unknown,
+ /* ForceSimpleCall */ false,
+ /* CheckCancelFlag */ true, ThreadID);
+ BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
+ BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
+ BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+ // if (lane_id == 0)
+ Value *IsWarpMaster = Builder.CreateIsNull(LaneID, "warp_master");
+ Builder.CreateCondBr(IsWarpMaster, ThenBB, ElseBB);
+ emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
+
+ // Reduce element = LocalReduceList[i]
+ auto *RedListArrayTy =
+ ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+ Type *IndexTy = Builder.getIndexTy(
+ M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+ Value *ElemPtrPtr =
+ Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
+ {ConstantInt::get(IndexTy, 0),
+ ConstantInt::get(IndexTy, En.index())});
+ // elemptr = ((CopyType*)(elemptrptr)) + I
+ Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+ if (NumIters > 1)
+ ElemPtr = Builder.CreateGEP(Builder.getInt32Ty(), ElemPtr, Cnt);
+
+ // Get pointer to location in transfer medium.
+ // MediumPtr = &medium[warp_id]
+ Value *MediumPtr = Builder.CreateInBoundsGEP(
+ ArrayTy, TransferMedium, {Builder.getInt64(0), WarpID});
+ // elem = *elemptr
+ //*MediumPtr = elem
+ Value *Elem = Builder.CreateLoad(CType, ElemPtr);
+ // Store the source element value to the dest element address.
+ Builder.CreateStore(Elem, MediumPtr,
+ /*IsVolatile*/ true);
+ Builder.CreateBr(MergeBB);
+
+ // else
+ emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
+ Builder.CreateBr(MergeBB);
+
+ // endif
+ // Barrier between the warp masters' writes to the medium above and the
+ // reads from it below.
+ emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
+ createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
+ omp::Directive::OMPD_unknown,
+ /* ForceSimpleCall */ false,
+ /* CheckCancelFlag */ true, ThreadID);
+
+ // Warp 0 copies reduce element from transfer medium
+ BasicBlock *W0ThenBB = BasicBlock::Create(Ctx, "then");
+ BasicBlock *W0ElseBB = BasicBlock::Create(Ctx, "else");
+ BasicBlock *W0MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+ Value *NumWarpsVal =
+ Builder.CreateLoad(Builder.getInt32Ty(), NumWarpsAddrCast);
+ // Up to 32 threads in warp 0 are active.
+ // Concretely: threads with GPUThreadID < NumWarps each read slot
+ // [thread_id].
+ Value *IsActiveThread =
+ Builder.CreateICmpULT(GPUThreadID, NumWarpsVal, "is_active_thread");
+ Builder.CreateCondBr(IsActiveThread, W0ThenBB, W0ElseBB);
+
+ emitBlock(W0ThenBB, Builder.GetInsertBlock()->getParent());
+
+ // SrcMediumPtr = &medium[tid]
+ // SrcMediumVal = *SrcMediumPtr
+ Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
+ ArrayTy, TransferMedium, {Builder.getInt64(0), GPUThreadID});
+ // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
+ Value *TargetElemPtrPtr =
+ Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
+ {ConstantInt::get(IndexTy, 0),
+ ConstantInt::get(IndexTy, En.index())});
+ Value *TargetElemPtrVal =
+ Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtrPtr);
+ Value *TargetElemPtr = TargetElemPtrVal;
+ if (NumIters > 1)
+ TargetElemPtr =
+ Builder.CreateGEP(Builder.getInt32Ty(), TargetElemPtr, Cnt);
+
+ // *TargetElemPtr = SrcMediumVal;
+ Value *SrcMediumValue =
+ Builder.CreateLoad(CType, SrcMediumPtrVal, /*IsVolatile*/ true);
+ Builder.CreateStore(SrcMediumValue, TargetElemPtr);
+ Builder.CreateBr(W0MergeBB);
+
+ emitBlock(W0ElseBB, Builder.GetInsertBlock()->getParent());
+ Builder.CreateBr(W0MergeBB);
+
+ emitBlock(W0MergeBB, Builder.GetInsertBlock()->getParent());
+
+ // Loop latch: increment the chunk counter and branch back to "precond".
+ if (NumIters > 1) {
+ Cnt = Builder.CreateNSWAdd(
+ Cnt, ConstantInt::get(Builder.getInt32Ty(), /*V=*/1));
+ Builder.CreateStore(Cnt, CntAddr, /*Volatile=*/false);
+
+ auto *CurFn = Builder.GetInsertBlock()->getParent();
+ emitBranch(PrecondBB);
+ emitBlock(ExitBB, CurFn);
+ }
+ RealTySize %= TySize;
+ }
+ }
+
+ Builder.CreateRetVoid();
+ Builder.restoreIP(SavedIP);
+
+ return WcFunc;
+}
+
+/// Emits "_omp_reduction_shuffle_and_reduce_func", an internal function of
+/// type `void(ptr ReduceList, i16 LaneId, i16 RemoteLaneOffset, i16 AlgoVer)`.
+/// The emitted function:
+///  1. copies the reduce list from the lane RemoteLaneOffset away into a
+///     stack-local remote list;
+///  2. calls \p ReduceFn(local, remote) when the AlgoVer-dependent predicate
+///     below holds;
+///  3. when AlgoVer == 1 and LaneId >= RemoteLaneOffset, copies the remote
+///     list back into the local one.
+/// The builder's insertion point is saved on entry and restored before
+/// returning, so the caller's position is not left inside the new function.
+Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
+    AttributeList FuncAttrs) {
+  InsertPointTy SavedIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy =
+      FunctionType::get(Builder.getVoidTy(),
+                        {Builder.getPtrTy(), Builder.getInt16Ty(),
+                         Builder.getInt16Ty(), Builder.getInt16Ty()},
+                        /* IsVarArg */ false);
+  Function *SarFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_shuffle_and_reduce_func", &M);
+  SarFunc->setAttributes(FuncAttrs);
+  SarFunc->addParamAttr(0, Attribute::NoUndef);
+  SarFunc->addParamAttr(1, Attribute::NoUndef);
+  SarFunc->addParamAttr(2, Attribute::NoUndef);
+  SarFunc->addParamAttr(3, Attribute::NoUndef);
+  SarFunc->addParamAttr(1, Attribute::SExt);
+  SarFunc->addParamAttr(2, Attribute::SExt);
+  SarFunc->addParamAttr(3, Attribute::SExt);
+  BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
+  Builder.SetInsertPoint(EntryBB);
+
+  // Thread local Reduce list used to host the values of data to be reduced.
+  Argument *ReduceListArg = SarFunc->getArg(0);
+  // Current lane id; could be logical.
+  Argument *LaneIDArg = SarFunc->getArg(1);
+  // Offset of the remote source lane relative to the current lane.
+  Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
+  // Algorithm version. This is expected to be known at compile time.
+  Argument *AlgoVerArg = SarFunc->getArg(3);
+
+  Type *ReduceListArgType = ReduceListArg->getType();
+  Type *LaneIDArgType = LaneIDArg->getType();
+  Type *LaneIDArgPtrType = LaneIDArg->getType()->getPointerTo();
+  Value *ReduceListAlloca = Builder.CreateAlloca(
+      ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
+  Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
+                                             LaneIDArg->getName() + ".addr");
+  Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
+      LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
+  Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
+                                              AlgoVerArg->getName() + ".addr");
+  ArrayType *RedListArrayTy =
+      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+
+  // Create a local thread-private variable to host the Reduce list
+  // from a remote lane.
+  Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
+      RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");
+
+  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListAlloca, ReduceListArgType,
+      ReduceListAlloca->getName() + ".ascast");
+  Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
+  Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      RemoteLaneOffsetAlloca, LaneIDArgPtrType,
+      RemoteLaneOffsetAlloca->getName() + ".ascast");
+  Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
+  Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      RemoteReductionListAlloca, Builder.getPtrTy(),
+      RemoteReductionListAlloca->getName() + ".ascast");
+
+  Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
+  Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
+  Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
+  Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);
+
+  Value *ReduceList = Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
+  Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
+  Value *RemoteLaneOffset =
+      Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
+  Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);
+
+  // Temporaries created while copying the remote list are placed right after
+  // the remote-list alloca.
+  InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);
+
+  // This loop iterates through the list of reduce elements and copies,
+  // element by element, from a remote lane in the warp to RemoteReduceList,
+  // hosted on the thread's stack.
+  emitReductionListCopy(
+      AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
+      ReduceList, RemoteListAddrCast, {RemoteLaneOffset, nullptr, nullptr});
+
+  // The actions to be performed on the Remote Reduce list is dependent
+  // on the algorithm version.
+  //
+  // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
+  // LaneId % 2 == 0 && Offset > 0):
+  //   do the reduction value aggregation
+  //
+  // The thread local variable Reduce list is mutated in place to host the
+  // reduced data, which is the aggregated value produced from local and
+  // remote lanes.
+  //
+  // Note that AlgoVer is expected to be a constant integer known at compile
+  // time.
+  // When AlgoVer==0, the first conjunction evaluates to true, making
+  // the entire predicate true during compile time.
+  // When AlgoVer==1, the second conjunction has only the second part to be
+  // evaluated during runtime. Other conjunctions evaluates to false
+  // during compile time.
+  // When AlgoVer==2, the third conjunction has only the second part to be
+  // evaluated during runtime. Other conjunctions evaluates to false
+  // during compile time.
+  Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
+  Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
+  Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
+  Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
+  Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
+  Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
+  Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
+  Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
+  Value *RemoteOffsetComp =
+      Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
+  Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
+  Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
+  Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);
+
+  BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
+  BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
+  BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
+
+  Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
+  emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
+  // reduce_function(LocalReduceList, RemoteReduceList)
+  Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceList, Builder.getPtrTy());
+  Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      RemoteListAddrCast, Builder.getPtrTy());
+  Builder.CreateCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
+      ->addFnAttr(Attribute::NoUnwind);
+  Builder.CreateBr(MergeBB);
+
+  emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
+  Builder.CreateBr(MergeBB);
+
+  emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
+
+  // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
+  // Reduce list.
+  Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
+  Value *LaneIdGtOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
+  Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGtOffset);
+
+  BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
+  BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
+  BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
+  Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);
+
+  emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());
+  emitReductionListCopy(AllocaIP, CopyAction::ThreadCopy, RedListArrayTy,
+                        ReductionInfos, RemoteListAddrCast, ReduceList);
+  Builder.CreateBr(CpyMergeBB);
+
+  emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
+  Builder.CreateBr(CpyMergeBB);
+
+  emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());
+
+  Builder.CreateRetVoid();
+  Builder.restoreIP(SavedIP);
+
+  return SarFunc;
+}
+
+/// Emits "_omp_reduction_list_to_global_copy_func", an internal function of
+/// type `void(ptr Buffer, i32 Idx, ptr ReduceList)`. It copies element i of
+/// the thread-local reduce list into field i of `Buffer[Idx]`, where Buffer
+/// is indexed through \p ReductionsBufferTy. The builder's insertion point is
+/// saved on entry and restored before returning.
+Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
+    AttributeList FuncAttrs) {
+  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
+  LLVMContext &Ctx = M.getContext();
+  FunctionType *FuncTy = FunctionType::get(
+      Builder.getVoidTy(),
+      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
+      /* IsVarArg */ false);
+  Function *LtGCFunc =
+      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_list_to_global_copy_func", &M);
+  LtGCFunc->setAttributes(FuncAttrs);
+  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
+  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
+  LtGCFunc->addParamAttr(2, Attribute::NoUndef);
+
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
+  Builder.SetInsertPoint(EntryBlock);
+
+  // Buffer: global reduction buffer.
+  Argument *BufferArg = LtGCFunc->getArg(0);
+  // Idx: index of the buffer.
+  Argument *IdxArg = LtGCFunc->getArg(1);
+  // ReduceList: thread local Reduce list.
+  Argument *ReduceListArg = LtGCFunc->getArg(2);
+
+  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
+                                                BufferArg->getName() + ".addr");
+  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
+                                             IdxArg->getName() + ".addr");
+  Value *ReduceListArgAlloca = Builder.CreateAlloca(
+      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
+  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      BufferArgAlloca, Builder.getPtrTy(),
+      BufferArgAlloca->getName() + ".ascast");
+  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
+  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ReduceListArgAlloca, Builder.getPtrTy(),
+      ReduceListArgAlloca->getName() + ".ascast");
+
+  Builder.CreateStore(BufferArg, BufferArgAddrCast);
+  Builder.CreateStore(IdxArg, IdxArgAddrCast);
+  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
+
+  Value *LocalReduceList =
+      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+  Value *BufferArgVal =
+      Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
+  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  // The reduce list is an array of element pointers; its type does not
+  // change between iterations.
+  auto *RedListArrayTy =
+      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+  for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    // Reduce element = LocalReduceList[i]
+    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
+        RedListArrayTy, LocalReduceList,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    // elemptr = ((CopyType*)(elemptrptr)) + I
+    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+
+    // Global = Buffer.VD[Idx];
+    Value *BufferVD =
+        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
+    Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
+        ReductionsBufferTy, BufferVD, 0, En.index(), "sum");
+
+    switch (RI.EvaluationKind) {
+    case EvaluationKindTy::Scalar: {
+      Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
+      Builder.CreateStore(TargetElement, GlobVal);
+      break;
+    }
+    case EvaluationKindTy::Complex:
+      // Complex elements are copied bytewise, like aggregates. This needs no
+      // assumption about their in-memory layout, and it guarantees the buffer
+      // slot is actually written.
+    case EvaluationKindTy::Aggregate: {
+      Value *SizeVal =
+          Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
+      Builder.CreateMemCpy(
+          GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
+          M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
+      break;
+    }
+    }
+  }
+
+  Builder.CreateRetVoid();
+  Builder.restoreIP(OldIP);
+  return LtGCFunc;
+}
+
+/// Emits the internal helper `_omp_reduction_list_to_global_reduce_func`,
+/// whose signature is `void (ptr Buffer, i32 Idx, ptr ReduceList)`.
+///
+/// The helper body does the following:
+///  - spills its three arguments to allocas and reloads them;
+///  - fills a local `[N x ptr]` array so entry I points at field I of row
+///    `Idx` of the global buffer (the buffer is indexed as an array of
+///    \p ReductionsBufferTy);
+///  - calls `ReduceFn(<that array>, ReduceList)`.
+///
+/// Every alloca is cast to the generic `ptr` type before use. The builder's
+/// insertion point is saved on entry and restored before returning.
+///
+/// \param ReductionInfos      one entry per reduction variable; only the
+///                            number of entries is used here.
+/// \param ReduceFn            callee invoked with (global list, ReduceList).
+/// \param ReductionsBufferTy  type of one row of the global buffer.
+/// \param FuncAttrs           attribute list installed on the new function.
+/// \returns the newly created helper function.
+Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
+    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
+    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+  InsertPointTy SavedIP = Builder.saveIP();
+
+  Type *PtrTy = Builder.getPtrTy();
+  Type *I32Ty = Builder.getInt32Ty();
+  FunctionType *HelperTy =
+      FunctionType::get(Builder.getVoidTy(), {PtrTy, I32Ty, PtrTy},
+                        /*isVarArg=*/false);
+  Function *Helper =
+      Function::Create(HelperTy, GlobalVariable::InternalLinkage,
+                       "_omp_reduction_list_to_global_reduce_func", &M);
+  Helper->setAttributes(FuncAttrs);
+  // Every parameter of the helper is marked noundef.
+  for (unsigned ArgNo = 0; ArgNo != HelperTy->getNumParams(); ++ArgNo)
+    Helper->addParamAttr(ArgNo, Attribute::NoUndef);
+
+  Builder.SetInsertPoint(BasicBlock::Create(M.getContext(), "entry", Helper));
+
+  Argument *BufferParam = Helper->getArg(0);     // global reduction buffer
+  Argument *IdxParam = Helper->getArg(1);        // row index into the buffer
+  Argument *ReduceListParam = Helper->getArg(2); // thread-local reduce list
+
+  // Stack slots for the incoming arguments.
+  Value *BufferSlot =
+      Builder.CreateAlloca(PtrTy, nullptr, BufferParam->getName() + ".addr");
+  Value *IdxSlot =
+      Builder.CreateAlloca(I32Ty, nullptr, IdxParam->getName() + ".addr");
+  Value *ReduceListSlot = Builder.CreateAlloca(
+      PtrTy, nullptr, ReduceListParam->getName() + ".addr");
+
+  // Local array of pointers that becomes ReduceFn's first argument:
+  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
+  ArrayType *RedListArrayTy = ArrayType::get(PtrTy, ReductionInfos.size());
+  Value *GlobalRedList =
+      Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
+
+  // Cast each alloca to the generic `ptr` type (an address-space cast when
+  // the alloca's address space differs, otherwise a plain pointer cast).
+  auto AsGenericPtr = [&](Value *Slot) {
+    return Builder.CreatePointerBitCastOrAddrSpaceCast(
+        Slot, PtrTy, Slot->getName() + ".ascast");
+  };
+  Value *BufferSlotCast = AsGenericPtr(BufferSlot);
+  Value *IdxSlotCast = AsGenericPtr(IdxSlot);
+  Value *ReduceListSlotCast = AsGenericPtr(ReduceListSlot);
+  Value *GlobalRedListCast = AsGenericPtr(GlobalRedList);
+
+  Builder.CreateStore(BufferParam, BufferSlotCast);
+  Builder.CreateStore(IdxParam, IdxSlotCast);
+  Builder.CreateStore(ReduceListParam, ReduceListSlotCast);
+
+  Value *Buffer = Builder.CreateLoad(PtrTy, BufferSlotCast);
+  Value *RowIdx[] = {Builder.CreateLoad(I32Ty, IdxSlotCast)};
+  const auto &DL = M.getDataLayout();
+  Type *IndexTy = Builder.getIndexTy(DL, DL.getDefaultGlobalsAddressSpace());
+
+  // RedList[I] = &Buffer[RowIdx].<field I>
+  for (unsigned I = 0, E = ReductionInfos.size(); I != E; ++I) {
+    Value *ListSlot = Builder.CreateInBoundsGEP(
+        RedListArrayTy, GlobalRedListCast,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, I)});
+    Value *Row = Builder.CreateInBoundsGEP(ReductionsBufferTy, Buffer, RowIdx);
+    Value *Field =
+        Builder.CreateConstInBoundsGEP2_32(ReductionsBufferTy, Row, 0, I, "sum");
+    Builder.CreateStore(Field, ListSlot);
+  }
+
+  // reduce_function(GlobalReduceList, ReduceList)
+  Value *ReduceList = Builder.CreateLoad(PtrTy, ReduceListSlotCast);
+  CallInst *ReduceCall =
+      Builder.CreateCall(ReduceFn, {GlobalRedListCast, ReduceList});
+  ReduceCall->addFnAttr(Attribute::NoUnwind);
+  Builder.CreateRetVoid();
+
+  Builder.restoreIP(SavedIP);
+  return Helper;
+}
+
+Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
+ ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
+ AttributeList FuncAttrs) {
+ OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
+ LLVMContext &Ctx = M.getContext();
+ FunctionType *FuncTy = FunctionType::get(
+ Builder.getVoidTy(),
+ {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
+ /* IsVarArg */ false);
+ Function *LtGCFunc =
+ Function::Create(FuncTy, GlobalVariable::InternalLinkage,
+ "_omp_reduction_global_to_list_copy_func", &M);
+ LtGCFunc->setAttributes(FuncAttrs);
+ LtGCFunc->addParamAttr(0, Attribute::NoUndef);
+ LtGCFunc->addParamAttr(1, Attribute::NoUndef);
+ LtGCFunc->addParamAttr(2, Attribute::NoUndef);
+
+ BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
+ Builder.SetInsertPoint(EntryBlock);
+
+ // Buffer: global reduction buffer.
+ Argument *BufferArg = LtGCFunc->getArg(0);
+ // Idx: index of the buffer.
+ Argument *IdxArg = LtGCFunc->getArg(1);
+ // ReduceList: thread local Reduce list.
+ Argument *ReduceListArg = LtGCFunc->getArg(2);
+
+ Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
+ BufferArg->getName() + ".addr");
+ Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
+ IdxArg->getName() + ".addr");
+ Value *ReduceListArgAlloca = Builder.CreateAlloca(
+ Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
+ Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ BufferArgAlloca, Builder.getPtrTy(),
+ BufferArgAlloca->getName() + ".ascast");
+ Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
+ Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ReduceListArgAlloca, Builder.getPtrTy(),
+ ReduceListArgAlloca->getName() + ".ascast");
+ Builder.CreateStore(BufferArg, BufferArgAddrCast);
+ Builder.CreateStore(IdxArg, IdxArgAddrCast);
+ Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
+
+ Value *LocalReduceList =
+ Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
+ Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
+ Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
+ Type *IndexTy = Builder.getIndexTy(
+ M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+ for (auto En : enumerate(ReductionInfos)) {
+ const OpenMPIRBuilder::ReductionInfo &RI = En.value();
+ auto *RedListArrayTy =
+ ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
+ // Reduce element = LocalReduceList[i]
+ Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
+ RedListArrayTy, LocalReduceList,
+ {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+ // elemptr = ((CopyType*)(elemptrptr)) + I
+ Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
+ // Global = Buffer.VD[Idx];
+ Value *BufferVD =
+ Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
+ Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
+ ReductionsBufferTy, BufferVD, 0, En.index(), "sum");
+
+ switch (RI.EvaluationKind) {
+ case EvaluationKindTy::Scalar: {
+ Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
+ Builder.CreateStore(TargetElement, ElemPtr);
+ break;
+ }
+ case EvaluationKindTy::Complex: {
+ // FIXME(Jan): Complex type
----------------
TIFitis wrote:
Yes, there are currently no tests for the Complex type, only for Scalar and Aggregate.
I don't see complex types handled anywhere else in the OMPIRBuilder, so I'm still working out what needs to be done here. In the meantime, I've added assertions that fail when a complex type is encountered.
https://github.com/llvm/llvm-project/pull/80343
More information about the Mlir-commits
mailing list