[llvm] [OpenMP] [IR Builder] Changes to Support Scan Operation (PR #136035)

Anchu Rajendran S via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 16 11:33:03 PDT 2025


================
@@ -4011,6 +4013,336 @@ OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
                               /*Conditional*/ true, /*hasFinalize*/ true);
 }
 
+// Emits a call to Callee with arguments Args at the builder's current
+// insertion point, giving the result the name Name. The call carries no
+// operand bundles and is marked as not throwing. Returns the new call.
+llvm::CallInst *
+OpenMPIRBuilder::emitNoUnwindRuntimeCall(llvm::FunctionCallee Callee,
+                                         ArrayRef<llvm::Value *> Args,
+                                         const llvm::Twine &Name) {
+  // An explicitly empty operand-bundle list.
+  const SmallVector<llvm::OperandBundleDef, 1> NoBundles;
+  llvm::CallInst *RuntimeCall =
+      Builder.CreateCall(Callee, Args, NoBundles, Name);
+  RuntimeCall->setDoesNotThrow();
+  return RuntimeCall;
+}
+
+// Lowers a scan directive encountered inside a loop body.
+// The incoming basic block is expected to be dominated by BeforeScanBB, and
+// the code that follows the scan directive should end up dominated by
+// AfterScanBB. The directive divides the code sequence into an input phase
+// and a scan phase. The first scan loop is the input loop and the second is
+// the scan loop; the jumps to the input and scan phase are chosen from
+// whether the inclusive or exclusive clause is used and from which of the two
+// loops is being lowered. Only inclusive scans are handled by the generated
+// code for now.
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
+    const LocationDescription &Loc, InsertPointTy AllocaIP,
+    ArrayRef<llvm::Value *> ScanVars, ArrayRef<llvm::Type *> ScanVarsType,
+    bool IsInclusive) {
+  // The buffer declarations are emitted only while lowering the input loop.
+  if (ScanInfo.OMPFirstScanLoop) {
+    llvm::Error DeclErr =
+        emitScanBasedDirectiveDeclsIR(AllocaIP, ScanVars, ScanVarsType);
+    if (DeclErr)
+      return DeclErr;
+  }
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  Function *ParentFn = Builder.GetInsertBlock()->getParent();
+
+  // Emits the address of buffer[IV] for the scan variable at position Idx:
+  // the buffer base is loaded from its pointer slot and then indexed by the
+  // induction value recorded in ScanInfo.
+  auto EmitBufferElemPtr = [&](size_t Idx) -> Value * {
+    Value *BuffSlot = ScanInfo.ScanBuffPtrs[ScanVars[Idx]];
+    Value *BuffBase = Builder.CreateLoad(Builder.getPtrTy(), BuffSlot);
+    llvm::Value *Index = ScanInfo.IV;
+    return Builder.CreateInBoundsGEP(ScanVarsType[Idx], BuffBase, Index,
+                                     "arrayOffset");
+  };
+
+  if (ScanInfo.OMPFirstScanLoop) {
+    // End of the input phase: buffer[i] = red;
+    for (size_t Idx = 0, NumVars = ScanVars.size(); Idx != NumVars; ++Idx) {
+      Value *ElemPtr = EmitBufferElemPtr(Idx);
+      Value *Current = Builder.CreateLoad(ScanVarsType[Idx], ScanVars[Idx]);
+      Builder.CreateStore(Current, ElemPtr);
+    }
+  }
+  Builder.CreateBr(ScanInfo.OMPScanLoopExit);
+  emitBlock(ScanInfo.OMPScanDispatch, ParentFn);
+
+  if (!ScanInfo.OMPFirstScanLoop) {
+    // Entrance to the scan phase: red = buffer[i];
+    // TODO: if exclusive scan, the red = buffer[i-1] needs to be updated.
+    for (size_t Idx = 0, NumVars = ScanVars.size(); Idx != NumVars; ++Idx) {
+      Value *ElemPtr = EmitBufferElemPtr(Idx);
+      Value *Scanned = Builder.CreateLoad(ScanVarsType[Idx], ElemPtr);
+      Builder.CreateStore(Scanned, ScanVars[Idx]);
+    }
+  }
+
+  // The branch condition is the constant true, so only the first successor
+  // is ever taken; which block that is depends on the loop being lowered and
+  // on the kind of scan.
+  // TODO: Update it to CreateBr and remove dead blocks
+  const bool BeforeBlockFirst = ScanInfo.OMPFirstScanLoop == IsInclusive;
+  BasicBlock *TakenBB = BeforeBlockFirst ? ScanInfo.OMPBeforeScanBlock
+                                         : ScanInfo.OMPAfterScanBlock;
+  BasicBlock *OtherBB = BeforeBlockFirst ? ScanInfo.OMPAfterScanBlock
+                                         : ScanInfo.OMPBeforeScanBlock;
+  Builder.CreateCondBr(Builder.getInt1(true), TakenBB, OtherBB);
+
+  emitBlock(ScanInfo.OMPAfterScanBlock, ParentFn);
+  Builder.SetInsertPoint(ScanInfo.OMPAfterScanBlock);
+  return Builder.saveIP();
+}
+
+// Emits the declarations used by a scan-based directive.
+// For every scan variable a pointer-typed alloca is created at AllocaIP and
+// recorded in ScanInfo.ScanBuffPtrs. Then, before the terminator of
+// ScanInfo.OMPScanInit, a masked region with filter value 0 is emitted that
+// mallocs a buffer of Span + 1 elements per variable and stores it into the
+// matching alloca. A barrier is emitted after the masked region.
+// Returns any error reported by createMasked or createBarrier.
+Error OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
+    InsertPointTy AllocaIP, ArrayRef<Value *> ScanVars,
+    ArrayRef<Type *> ScanVarsType) {
+  const size_t NumScanVars = ScanVars.size();
+
+  // Create the shared pointer slots at the alloca insertion point.
+  Builder.restoreIP(AllocaIP);
+  for (Value *ScanVar : ScanVars) {
+    llvm::Value *Slot =
+        Builder.CreateAlloca(Builder.getPtrTy(), nullptr, "vla");
+    ScanInfo.ScanBuffPtrs[ScanVar] = Slot;
+  }
+
+  // Body of the masked region: allocate the temporary buffers.
+  auto AllocBuffersCB = [&](InsertPointTy /*AllocaIP*/,
+                            InsertPointTy CodeGenIP) -> Error {
+    Builder.restoreIP(CodeGenIP);
+    Value *NumElems = Builder.CreateAdd(ScanInfo.Span, Builder.getInt32(1));
+    // Element sizes are expressed as 32-bit integers.
+    Type *SizeTy = Builder.getInt32Ty();
+    for (size_t Idx = 0; Idx != NumScanVars; ++Idx) {
+      Type *ElemTy = ScanVarsType[Idx];
+      Constant *ElemSize = ConstantExpr::getTruncOrBitCast(
+          ConstantExpr::getSizeOf(ElemTy), SizeTy);
+      Value *Buffer = Builder.CreateMalloc(SizeTy, ElemTy, ElemSize, NumElems,
+                                           nullptr, "arr");
+      Builder.CreateStore(Buffer, ScanInfo.ScanBuffPtrs[ScanVars[Idx]]);
+    }
+    return Error::success();
+  };
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto NoopFiniCB = [&](InsertPointTy) { return llvm::Error::success(); };
+
+  Builder.SetInsertPoint(ScanInfo.OMPScanInit->getTerminator());
+  llvm::Value *MaskFilter = Builder.getInt32(0);
+  InsertPointOrErrorTy MaskedIP =
+      createMasked(Builder.saveIP(), AllocBuffersCB, NoopFiniCB, MaskFilter);
+  if (!MaskedIP)
+    return MaskedIP.takeError();
+  Builder.restoreIP(*MaskedIP);
+
+  // If the block we landed in already ends in a terminator, emit the barrier
+  // in front of it.
+  if (auto *Term = Builder.GetInsertBlock()->getTerminator())
+    Builder.SetInsertPoint(Term);
+  InsertPointOrErrorTy BarrierIP =
+      createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+  if (!BarrierIP)
+    return BarrierIP.takeError();
+  Builder.restoreIP(*BarrierIP);
+
+  return Error::success();
+}
+
+// Emits the finalization code of a scan-based directive.
+// At ScanInfo.OMPScanFinish (before its terminator when it has one, otherwise
+// at the end of the block) a masked region with filter value 0 is emitted.
+// For every entry of ReductionInfos it loads the buffer registered for the
+// private variable in ScanInfo.ScanBuffPtrs, copies element buffer[Span] into
+// the original variable and frees the buffer. A barrier is emitted after the
+// masked region.
+// Returns any error reported by createMasked or createBarrier.
+Error OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
+    SmallVector<ReductionInfo> ReductionInfos) {
+  auto BodyGenCB = [&](InsertPointTy AllocaIP,
+                       InsertPointTy CodeGenIP) -> Error {
+    Builder.restoreIP(CodeGenIP);
+    // The entries are only read here, so bind them by const reference
+    // instead of copying a ReductionInfo on every iteration.
+    for (const ReductionInfo &RedInfo : ReductionInfos) {
+      Value *PrivateVar = RedInfo.PrivateVariable;
+      Value *OrigVar = RedInfo.Variable;
+      Value *BuffPtr = ScanInfo.ScanBuffPtrs[PrivateVar];
+      Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr);
+
+      // Read buffer[Span] and write it back to the original variable.
+      Type *SrcTy = RedInfo.ElementType;
+      Value *Val =
+          Builder.CreateInBoundsGEP(SrcTy, Buff, ScanInfo.Span, "arrayOffset");
+      Value *Src = Builder.CreateLoad(SrcTy, Val);
+
+      Builder.CreateStore(Src, OrigVar);
+      // The temporary buffer is no longer needed once the value is copied.
+      Builder.CreateFree(Buff);
+    }
+    return Error::success();
+  };
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
+
+  // Insert before the terminator of the finish block when there is one,
+  // otherwise append to the block.
+  if (ScanInfo.OMPScanFinish->getTerminator())
+    Builder.SetInsertPoint(ScanInfo.OMPScanFinish->getTerminator());
+  else
+    Builder.SetInsertPoint(ScanInfo.OMPScanFinish);
+
+  llvm::Value *FilterVal = Builder.getInt32(0);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+      createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal);
+
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+  // If the block we landed in already ends in a terminator, emit the barrier
+  // in front of it.
+  BasicBlock *InputBB = Builder.GetInsertBlock();
+  if (InputBB->getTerminator())
+    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
+  AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+  return Error::success();
+}
+
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
+    const LocationDescription &Loc,
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos) {
+
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  auto BodyGenCB = [&](InsertPointTy AllocaIP,
+                       InsertPointTy CodeGenIP) -> Error {
+    Builder.restoreIP(CodeGenIP);
+    Function *CurFn = Builder.GetInsertBlock()->getParent();
+    // for (int k = 0; k <= ceil(log2(n)); ++k)
+    llvm::BasicBlock *LoopBB =
+        BasicBlock::Create(CurFn->getContext(), "omp.outer.log.scan.body");
+    llvm::BasicBlock *ExitBB =
+        splitBB(Builder, false, "omp.outer.log.scan.exit");
+    llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
+        Builder.GetInsertBlock()->getModule(),
+        (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy());
+    llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
+    llvm::Value *Arg =
+        Builder.CreateUIToFP(ScanInfo.Span, Builder.getDoubleTy());
+    llvm::Value *LogVal = emitNoUnwindRuntimeCall(F, Arg, "");
+    F = llvm::Intrinsic::getOrInsertDeclaration(
+        Builder.GetInsertBlock()->getModule(),
+        (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy());
+    LogVal = emitNoUnwindRuntimeCall(F, LogVal, "");
+    LogVal = Builder.CreateFPToUI(LogVal, Builder.getInt32Ty());
+    llvm::Value *NMin1 = Builder.CreateNUWSub(
+        ScanInfo.Span, llvm::ConstantInt::get(ScanInfo.Span->getType(), 1));
+    Builder.SetInsertPoint(InputBB);
+    Builder.CreateBr(LoopBB);
+    emitBlock(LoopBB, CurFn);
+    Builder.SetInsertPoint(LoopBB);
+
+    PHINode *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+    //// size pow2k = 1;
+    PHINode *Pow2K = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+    Counter->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 0),
+                         InputBB);
+    Pow2K->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 1),
+                       InputBB);
+    //// for (size i = n - 1; i >= 2 ^ k; --i)
+    ////   tmp[i] op= tmp[i-pow2k];
+    llvm::BasicBlock *InnerLoopBB =
+        BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.body");
+    llvm::BasicBlock *InnerExitBB =
+        BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.exit");
+    llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K);
+    Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+    emitBlock(InnerLoopBB, CurFn);
+    Builder.SetInsertPoint(InnerLoopBB);
+    auto *IVal = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+    IVal->addIncoming(NMin1, LoopBB);
+    for (ReductionInfo RedInfo : ReductionInfos) {
+      Value *ReductionVal = RedInfo.PrivateVariable;
+      Value *BuffPtr = ScanInfo.ScanBuffPtrs[ReductionVal];
+      Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr);
+      Type *DestTy = RedInfo.ElementType;
+      Value *IV = Builder.CreateAdd(IVal, Builder.getInt32(1));
+      Value *LHSPtr =
+          Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+      Value *OffsetIval = Builder.CreateNUWSub(IV, Pow2K);
+      Value *RHSPtr =
+          Builder.CreateInBoundsGEP(DestTy, Buff, OffsetIval, "arrayOffset");
+      Value *LHS = Builder.CreateLoad(DestTy, LHSPtr);
+      Value *RHS = Builder.CreateLoad(DestTy, RHSPtr);
+      llvm::Value *Result;
+      InsertPointOrErrorTy AfterIP =
+          RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
+      if (!AfterIP)
+        return AfterIP.takeError();
+      Builder.CreateStore(Result, LHSPtr);
+    }
+    llvm::Value *NextIVal = Builder.CreateNUWSub(
+        IVal, llvm::ConstantInt::get(Builder.getInt32Ty(), 1));
+    IVal->addIncoming(NextIVal, Builder.GetInsertBlock());
+    CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K);
+    Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+    emitBlock(InnerExitBB, CurFn);
+    llvm::Value *Next = Builder.CreateNUWAdd(
+        Counter, llvm::ConstantInt::get(Counter->getType(), 1));
+    Counter->addIncoming(Next, Builder.GetInsertBlock());
+    // pow2k <<= 1;
+    llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true);
+    Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock());
+    llvm::Value *Cmp = Builder.CreateICmpNE(Next, LogVal);
+    Builder.CreateCondBr(Cmp, LoopBB, ExitBB);
+    Builder.SetInsertPoint(ExitBB->getFirstInsertionPt());
+    return Error::success();
+  };
+
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
+
+  llvm::Value *FilterVal = Builder.getInt32(0);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+      createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal);
+
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+  AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+  Error Err = emitScanBasedDirectiveFinalsIR(ReductionInfos);
+  if (Err) {
+    return Err;
+  }
----------------
anchuraj wrote:

Updated.

https://github.com/llvm/llvm-project/pull/136035


More information about the llvm-commits mailing list