[llvm] [OpenMP] [IR Builder] Changes to Support Scan Operation (PR #136035)
Anchu Rajendran S via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 16 11:33:03 PDT 2025
================
@@ -4011,6 +4013,336 @@ OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
/*Conditional*/ true, /*hasFinalize*/ true);
}
+// Emits a call to Callee with Args at the current insertion point and marks
+// the resulting call as not throwing. An explicitly empty operand-bundle list
+// is passed to CreateCall. Returns the created call instruction.
+llvm::CallInst *
+OpenMPIRBuilder::emitNoUnwindRuntimeCall(llvm::FunctionCallee Callee,
+                                         ArrayRef<llvm::Value *> Args,
+                                         const llvm::Twine &Name) {
+  const ArrayRef<llvm::OperandBundleDef> NoBundles;
+  llvm::CallInst *RuntimeCall =
+      Builder.CreateCall(Callee, Args, NoBundles, Name);
+  RuntimeCall->setDoesNotThrow();
+  return RuntimeCall;
+}
+
+// Emits the IR for an OpenMP `scan` directive inside a scan-based loop.
+//
+// The scan directive splits the loop body into an input phase and a scan
+// phase, and the construct is lowered as two loops: the first one is the
+// input loop and the second one is the scan loop (ScanInfo.OMPFirstScanLoop
+// tells which of the two is currently being generated).
+//  - Input loop: at the end of the input phase, the current value of each
+//    scan variable is stored to its buffer (buffer[IV] = red).
+//  - Scan loop: at the entry of the scan phase, each scan variable is
+//    reloaded from its buffer (red = buffer[IV]).
+// Code before the directive is expected to be dominated by
+// ScanInfo.OMPBeforeScanBlock; code after it is emitted into (and dominated
+// by) ScanInfo.OMPAfterScanBlock. Depending on IsInclusive and on which loop
+// is being lowered, the dispatch block branches to the before- or after-scan
+// block. Only inclusive scans are fully handled: an exclusive scan would
+// still need red = buffer[IV - 1] in the scan loop (see TODO below).
+//
+// ScanVars and ScanVarsType are parallel arrays (one element type per scan
+// variable). On the first scan loop the buffers are set up first through
+// emitScanBasedDirectiveDeclsIR, whose error, if any, is returned. Returns
+// Loc.IP if updateToLocation(Loc) fails, otherwise an insertion point in
+// ScanInfo.OMPAfterScanBlock.
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
+    const LocationDescription &Loc, InsertPointTy AllocaIP,
+    ArrayRef<llvm::Value *> ScanVars, ArrayRef<llvm::Type *> ScanVarsType,
+    bool IsInclusive) {
+  assert(ScanVars.size() == ScanVarsType.size() &&
+         "every scan variable needs a matching element type");
+  if (ScanInfo.OMPFirstScanLoop) {
+    // The buffers must exist before the input loop writes to them.
+    if (llvm::Error Err =
+            emitScanBasedDirectiveDeclsIR(AllocaIP, ScanVars, ScanVarsType))
+      return Err;
+  }
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  llvm::Value *IV = ScanInfo.IV;
+  Function *CurFn = Builder.GetInsertBlock()->getParent();
+
+  if (ScanInfo.OMPFirstScanLoop) {
+    // Emit buffer[i] = red; at the end of the input phase.
+    for (size_t I = 0, E = ScanVars.size(); I < E; ++I) {
+      Value *Buff = Builder.CreateLoad(Builder.getPtrTy(),
+                                       ScanInfo.ScanBuffPtrs[ScanVars[I]]);
+      Type *DestTy = ScanVarsType[I];
+      Value *Val = Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+      Value *Src = Builder.CreateLoad(DestTy, ScanVars[I]);
+      Builder.CreateStore(Src, Val);
+    }
+  }
+  Builder.CreateBr(ScanInfo.OMPScanLoopExit);
+  emitBlock(ScanInfo.OMPScanDispatch, CurFn);
+
+  if (!ScanInfo.OMPFirstScanLoop) {
+    // Emit red = buffer[i]; at the entrance to the scan phase.
+    // TODO: if exclusive scan, the red = buffer[i-1] needs to be updated.
+    for (size_t I = 0, E = ScanVars.size(); I < E; ++I) {
+      Value *Buff = Builder.CreateLoad(Builder.getPtrTy(),
+                                       ScanInfo.ScanBuffPtrs[ScanVars[I]]);
+      Type *DestTy = ScanVarsType[I];
+      Value *SrcPtr =
+          Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+      Value *Src = Builder.CreateLoad(DestTy, SrcPtr);
+      Builder.CreateStore(Src, ScanVars[I]);
+    }
+  }
+
+  // TODO: Update it to CreateBr and remove dead blocks
+  // The condition is the constant `true`, so at run time only the first
+  // successor is ever taken.
+  llvm::Value *CmpI = Builder.getInt1(true);
+  if (ScanInfo.OMPFirstScanLoop == IsInclusive)
+    Builder.CreateCondBr(CmpI, ScanInfo.OMPBeforeScanBlock,
+                         ScanInfo.OMPAfterScanBlock);
+  else
+    Builder.CreateCondBr(CmpI, ScanInfo.OMPAfterScanBlock,
+                         ScanInfo.OMPBeforeScanBlock);
+  emitBlock(ScanInfo.OMPAfterScanBlock, CurFn);
+  Builder.SetInsertPoint(ScanInfo.OMPAfterScanBlock);
+  return Builder.saveIP();
+}
+
+// Sets up the temporary buffers used by a scan-based directive.
+//
+// For every scan variable, a pointer slot ("vla") is created at AllocaIP and
+// recorded in ScanInfo.ScanBuffPtrs. Then, at the terminator of
+// ScanInfo.OMPScanInit, a masked region (filter value 0) mallocs a buffer of
+// (Span + 1) elements of the variable's type and stores it into that slot.
+// A barrier follows the masked region (presumably so every thread sees the
+// stored buffer pointers before using them). ScanVars and ScanVarsType are
+// parallel arrays. Returns any error produced by createMasked/createBarrier.
+Error OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
+    InsertPointTy AllocaIP, ArrayRef<Value *> ScanVars,
+    ArrayRef<Type *> ScanVarsType) {
+  assert(ScanVars.size() == ScanVarsType.size() &&
+         "every scan variable needs a matching element type");
+
+  Builder.restoreIP(AllocaIP);
+  // Create the shared pointer at alloca IP.
+  for (Value *ScanVar : ScanVars) {
+    llvm::Value *BuffPtr =
+        Builder.CreateAlloca(Builder.getPtrTy(), nullptr, "vla");
+    ScanInfo.ScanBuffPtrs[ScanVar] = BuffPtr;
+  }
+
+  // Allocate temporary buffer by master thread
+  auto BodyGenCB = [&](InsertPointTy AllocaIP,
+                       InsertPointTy CodeGenIP) -> Error {
+    Builder.restoreIP(CodeGenIP);
+    // Build the "+ 1" in Span's own integer type (as emitScanReduction does
+    // for Span - 1) instead of assuming Span is i32; CreateMalloc converts
+    // the element count to IntPtrTy itself.
+    Value *AllocSpan = Builder.CreateAdd(
+        ScanInfo.Span, ConstantInt::get(ScanInfo.Span->getType(), 1));
+    Type *IntPtrTy = Builder.getInt32Ty();
+    for (size_t I = 0, E = ScanVars.size(); I < E; ++I) {
+      Constant *AllocSize = ConstantExpr::getTruncOrBitCast(
+          ConstantExpr::getSizeOf(ScanVarsType[I]), IntPtrTy);
+      Value *Buff = Builder.CreateMalloc(IntPtrTy, ScanVarsType[I], AllocSize,
+                                         AllocSpan, nullptr, "arr");
+      Builder.CreateStore(Buff, ScanInfo.ScanBuffPtrs[ScanVars[I]]);
+    }
+    return Error::success();
+  };
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
+
+  Builder.SetInsertPoint(ScanInfo.OMPScanInit->getTerminator());
+  llvm::Value *FilterVal = Builder.getInt32(0);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+      createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal);
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+
+  // If the block after the masked region is already terminated, emit the
+  // barrier in front of its terminator.
+  BasicBlock *InputBB = Builder.GetInsertBlock();
+  if (Instruction *Term = InputBB->getTerminator())
+    Builder.SetInsertPoint(Term);
+  AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+
+  return Error::success();
+}
+
+// Emits the epilogue of a scan-based directive. The code is placed before the
+// terminator of ScanInfo.OMPScanFinish, or at its end if it has none. Inside a
+// masked region (filter value 0), for each reduction it:
+//  - loads the element at index ScanInfo.Span of the scan buffer that is
+//    keyed by the reduction's private variable;
+//  - stores that element to the reduction's original variable;
+//  - frees the buffer.
+// A barrier is emitted after the masked region. Returns any error produced by
+// createMasked/createBarrier.
+Error OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
+    SmallVector<ReductionInfo> ReductionInfos) {
+  auto BodyGenCB = [&](InsertPointTy AllocaIP,
+                       InsertPointTy CodeGenIP) -> Error {
+    Builder.restoreIP(CodeGenIP);
+    // Iterate by reference: ReductionInfo carries callback members, so a
+    // by-value loop variable would copy them on every iteration.
+    for (const ReductionInfo &RedInfo : ReductionInfos) {
+      Value *BuffPtr = ScanInfo.ScanBuffPtrs[RedInfo.PrivateVariable];
+      Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr);
+
+      Type *SrcTy = RedInfo.ElementType;
+      Value *Val =
+          Builder.CreateInBoundsGEP(SrcTy, Buff, ScanInfo.Span, "arrayOffset");
+      Value *Src = Builder.CreateLoad(SrcTy, Val);
+
+      Builder.CreateStore(Src, RedInfo.Variable);
+      Builder.CreateFree(Buff);
+    }
+    return Error::success();
+  };
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
+
+  if (Instruction *FinishTerm = ScanInfo.OMPScanFinish->getTerminator())
+    Builder.SetInsertPoint(FinishTerm);
+  else
+    Builder.SetInsertPoint(ScanInfo.OMPScanFinish);
+
+  llvm::Value *FilterVal = Builder.getInt32(0);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+      createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal);
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+
+  // If the block after the masked region is already terminated, emit the
+  // barrier in front of its terminator.
+  BasicBlock *InputBB = Builder.GetInsertBlock();
+  if (Instruction *Term = InputBB->getTerminator())
+    Builder.SetInsertPoint(Term);
+  AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+  return Error::success();
+}
+
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
+ const LocationDescription &Loc,
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos) {
+
+ if (!updateToLocation(Loc))
+ return Loc.IP;
+ auto BodyGenCB = [&](InsertPointTy AllocaIP,
+ InsertPointTy CodeGenIP) -> Error {
+ Builder.restoreIP(CodeGenIP);
+ Function *CurFn = Builder.GetInsertBlock()->getParent();
+ // for (int k = 0; k <= ceil(log2(n)); ++k)
+ llvm::BasicBlock *LoopBB =
+ BasicBlock::Create(CurFn->getContext(), "omp.outer.log.scan.body");
+ llvm::BasicBlock *ExitBB =
+ splitBB(Builder, false, "omp.outer.log.scan.exit");
+ llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
+ Builder.GetInsertBlock()->getModule(),
+ (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy());
+ llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
+ llvm::Value *Arg =
+ Builder.CreateUIToFP(ScanInfo.Span, Builder.getDoubleTy());
+ llvm::Value *LogVal = emitNoUnwindRuntimeCall(F, Arg, "");
+ F = llvm::Intrinsic::getOrInsertDeclaration(
+ Builder.GetInsertBlock()->getModule(),
+ (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy());
+ LogVal = emitNoUnwindRuntimeCall(F, LogVal, "");
+ LogVal = Builder.CreateFPToUI(LogVal, Builder.getInt32Ty());
+ llvm::Value *NMin1 = Builder.CreateNUWSub(
+ ScanInfo.Span, llvm::ConstantInt::get(ScanInfo.Span->getType(), 1));
+ Builder.SetInsertPoint(InputBB);
+ Builder.CreateBr(LoopBB);
+ emitBlock(LoopBB, CurFn);
+ Builder.SetInsertPoint(LoopBB);
+
+ PHINode *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+ //// size pow2k = 1;
+ PHINode *Pow2K = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+ Counter->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 0),
+ InputBB);
+ Pow2K->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 1),
+ InputBB);
+ //// for (size i = n - 1; i >= 2 ^ k; --i)
+ //// tmp[i] op= tmp[i-pow2k];
+ llvm::BasicBlock *InnerLoopBB =
+ BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.body");
+ llvm::BasicBlock *InnerExitBB =
+ BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.exit");
+ llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K);
+ Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+ emitBlock(InnerLoopBB, CurFn);
+ Builder.SetInsertPoint(InnerLoopBB);
+ auto *IVal = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+ IVal->addIncoming(NMin1, LoopBB);
+ for (ReductionInfo RedInfo : ReductionInfos) {
+ Value *ReductionVal = RedInfo.PrivateVariable;
+ Value *BuffPtr = ScanInfo.ScanBuffPtrs[ReductionVal];
+ Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr);
+ Type *DestTy = RedInfo.ElementType;
+ Value *IV = Builder.CreateAdd(IVal, Builder.getInt32(1));
+ Value *LHSPtr =
+ Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+ Value *OffsetIval = Builder.CreateNUWSub(IV, Pow2K);
+ Value *RHSPtr =
+ Builder.CreateInBoundsGEP(DestTy, Buff, OffsetIval, "arrayOffset");
+ Value *LHS = Builder.CreateLoad(DestTy, LHSPtr);
+ Value *RHS = Builder.CreateLoad(DestTy, RHSPtr);
+ llvm::Value *Result;
+ InsertPointOrErrorTy AfterIP =
+ RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
+ if (!AfterIP)
+ return AfterIP.takeError();
+ Builder.CreateStore(Result, LHSPtr);
+ }
+ llvm::Value *NextIVal = Builder.CreateNUWSub(
+ IVal, llvm::ConstantInt::get(Builder.getInt32Ty(), 1));
+ IVal->addIncoming(NextIVal, Builder.GetInsertBlock());
+ CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K);
+ Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+ emitBlock(InnerExitBB, CurFn);
+ llvm::Value *Next = Builder.CreateNUWAdd(
+ Counter, llvm::ConstantInt::get(Counter->getType(), 1));
+ Counter->addIncoming(Next, Builder.GetInsertBlock());
+ // pow2k <<= 1;
+ llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true);
+ Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock());
+ llvm::Value *Cmp = Builder.CreateICmpNE(Next, LogVal);
+ Builder.CreateCondBr(Cmp, LoopBB, ExitBB);
+ Builder.SetInsertPoint(ExitBB->getFirstInsertionPt());
+ return Error::success();
+ };
+
+ // TODO: Perform finalization actions for variables. This has to be
+ // called for variables which have destructors/finalizers.
+ auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
+
+ llvm::Value *FilterVal = Builder.getInt32(0);
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+ createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal);
+
+ if (!AfterIP)
+ return AfterIP.takeError();
+ Builder.restoreIP(*AfterIP);
+ AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+
+ if (!AfterIP)
+ return AfterIP.takeError();
+ Builder.restoreIP(*AfterIP);
+ Error Err = emitScanBasedDirectiveFinalsIR(ReductionInfos);
+ if (Err) {
+ return Err;
+ }
----------------
anchuraj wrote:
Updated.
https://github.com/llvm/llvm-project/pull/136035
More information about the llvm-commits
mailing list