[llvm] [OpenMP] [IR Builder] Changes to Support Scan Operation (PR #136035)
Anchu Rajendran S via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 18 13:41:01 PDT 2025
https://github.com/anchuraj updated https://github.com/llvm/llvm-project/pull/136035
>From 6d59b7ec3b52dfddf50131c81f44af8eb9c0c84a Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Wed, 16 Apr 2025 16:03:02 -0500
Subject: [PATCH] IR Builder Changes to Support Scan Operation
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 139 +++++-
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 423 +++++++++++++++++-
.../Frontend/OpenMPIRBuilderTest.cpp | 95 ++++
3 files changed, 652 insertions(+), 5 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 6b104708bdb0d..682e1699078af 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -503,6 +503,31 @@ class OpenMPIRBuilder {
return allocaInst;
}
};
+
+ struct ScanInformation {
+   /// Dominates the body of the loop before scan directive
+   llvm::BasicBlock *OMPBeforeScanBlock = nullptr;
+   /// Dominates the body of the loop after scan directive
+   llvm::BasicBlock *OMPAfterScanBlock = nullptr;
+   /// Controls the flow to before or after scan blocks
+   llvm::BasicBlock *OMPScanDispatch = nullptr;
+   /// Exit block of loop body
+   llvm::BasicBlock *OMPScanLoopExit = nullptr;
+   /// Block before loop body where scan initializations are done
+   llvm::BasicBlock *OMPScanInit = nullptr;
+   /// Block after loop body where scan finalizations are done
+   llvm::BasicBlock *OMPScanFinish = nullptr;
+   /// If true, it indicates Input phase is lowered; else it indicates
+   /// ScanPhase is lowered
+   bool OMPFirstScanLoop = false;
+   /// Maps each private reduction variable to its temporary buffer pointer
+   llvm::SmallDenseMap<llvm::Value *, llvm::Value *> ScanBuffPtrs;
+   /// Loop induction variable used to index the temporary buffers
+   llvm::Value *IV = nullptr;
+   /// Trip count of the scan loops; the buffers hold Span + 1 elements
+   llvm::Value *Span = nullptr;
+ } ScanInfo;
+
/// Initialize the internal state, this will put structures types and
/// potentially other helpers into the underlying module. Must be called
/// before any other method and only once! This internal state includes types
@@ -729,6 +754,35 @@ class OpenMPIRBuilder {
LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
const Twine &Name = "loop");
+ /// Generator for the control flow structure of OpenMP canonical loops if
+ /// the parent directive has an `inscan` modifier specified.
+ /// If the `inscan` modifier is specified, the region of the parent is
+ /// expected to have a `scan` directive. Based on the clauses in
+ /// scan directive, the body of the loop is split into two loops: Input loop
+ /// and Scan Loop. Input loop contains the code generated for input phase of
+ /// scan and Scan loop contains the code generated for scan phase of scan.
+ ///
+ /// \param Loc The insert and source location description.
+ /// \param BodyGenCB Callback that will generate the loop body code.
+ /// \param Start Value of the loop counter for the first iterations.
+ /// \param Stop Loop counter values past this will stop the loop.
+ /// \param Step Loop counter increment after each iteration; negative
+ /// means counting down.
+ /// \param IsSigned Whether Start, Stop and Step are signed integers.
+ /// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
+ /// counter.
+ /// \param ComputeIP Insertion point for instructions computing the trip
+ /// count. Can be used to ensure the trip count is available
+ /// at the outermost loop of a loop nest. If not set,
+ /// defaults to the preheader of the generated loop.
+ /// \param Name Base name used to derive BB and instruction names.
+ ///
+ /// \returns A vector containing Loop Info of Input Loop and Scan Loop.
+ Expected<SmallVector<llvm::CanonicalLoopInfo *>> createCanonicalScanLoops(
+ const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+ Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+ InsertPointTy ComputeIP, const Twine &Name);
+
/// Calculate the trip count of a canonical loop.
///
/// This allows specifying user-defined loop counter values using increment,
@@ -798,13 +852,16 @@ class OpenMPIRBuilder {
/// at the outermost loop of a loop nest. If not set,
/// defaults to the preheader of the generated loop.
/// \param Name Base name used to derive BB and instruction names.
+ /// \param InScan Whether loop has a scan reduction specified.
///
/// \returns An object representing the created control flow structure which
/// can be used for loop-associated directives.
- Expected<CanonicalLoopInfo *> createCanonicalLoop(
- const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
- Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
- InsertPointTy ComputeIP = {}, const Twine &Name = "loop");
+ Expected<CanonicalLoopInfo *>
+ createCanonicalLoop(const LocationDescription &Loc,
+ LoopBodyGenCallbackTy BodyGenCB, Value *Start,
+ Value *Stop, Value *Step, bool IsSigned,
+ bool InclusiveStop, InsertPointTy ComputeIP = {},
+ const Twine &Name = "loop", bool InScan = false);
/// Collapse a loop nest into a single loop.
///
@@ -1532,6 +1589,45 @@ class OpenMPIRBuilder {
ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
Function *ReduceFn, AttributeList FuncAttrs);
+ /// Creates the runtime call specified
+ /// \param Callee Function Declaration Value
+ /// \param Args Arguments passed to the call
+ /// \param Name Optional param to specify the name of the call Instruction.
+ ///
+ /// \return The Runtime call instruction created.
+ llvm::CallInst *emitNoUnwindRuntimeCall(llvm::FunctionCallee Callee,
+ ArrayRef<llvm::Value *> Args,
+ const llvm::Twine &Name);
+
+ /// Helper function for createCanonicalScanLoops to create the Input Loop
+ /// in the first generator and the Scan Loop in the second generator.
+ /// \param InputLoopGen Callback for generating the loop for input phase
+ /// \param ScanLoopGen Callback for generating the loop for scan phase
+ ///
+ /// \return error if any produced, else return success.
+ Error emitScanBasedDirectiveIR(
+ llvm::function_ref<Error()> InputLoopGen,
+ llvm::function_ref<Error(LocationDescription Loc)> ScanLoopGen);
+
+ /// Creates the basic blocks required for scan reduction.
+ void createScanBBs();
+
+ /// Dynamically allocates the buffer needed for scan reduction.
+ /// \param AllocaIP IP where the possibly-shared buffer pointer is declared.
+ /// \param ScanVars Scan Variables.
+ /// \param ScanVarsType Element types of the scan variables.
+ /// \return error if any produced, else return success.
+ Error emitScanBasedDirectiveDeclsIR(InsertPointTy AllocaIP,
+ ArrayRef<llvm::Value *> ScanVars,
+ ArrayRef<llvm::Type *> ScanVarsType);
+
+ /// Copies the result back to the reduction variable.
+ /// \param ReductionInfos Array type containing the ReductionOps.
+ ///
+ /// \return error if any produced, else return success.
+ Error emitScanBasedDirectiveFinalsIR(
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos);
+
/// This function emits a helper that gathers Reduce lists from the first
/// lane of every active warp to lanes in the first warp.
///
@@ -2607,6 +2703,41 @@ class OpenMPIRBuilder {
BodyGenCallbackTy BodyGenCB,
FinalizeCallbackTy FiniCB, Value *Filter);
+ /// This function performs the scan reduction of the values updated in
+ /// the input phase. The reduction logic needs to be emitted between input
+ /// and scan loop returned by `createCanonicalScanLoops`. The following
+ /// is the code that is generated, `buffer` and `span` are expected to be
+ /// populated before executing the generated code.
+ ///
+ /// for (int k = 0; k != ceil(log2(span)); ++k) {
+ /// i=pow(2,k)
+ /// for (size cnt = last_iter; cnt >= i; --cnt)
+ /// buffer[cnt] op= buffer[cnt-i];
+ /// }
+ /// \param Loc The insert and source location description.
+ /// \param ReductionInfos Array type containing the ReductionOps.
+ ///
+ /// \returns The insertion position *after* the scan reduction.
+ InsertPointOrErrorTy emitScanReduction(
+ const LocationDescription &Loc,
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos);
+
+ /// This directive splits and directs the control flow to input phase
+ /// blocks or scan phase blocks based on 1. whether input loop or scan loop
+ /// is executed, 2. whether exclusive or inclusive scan is used.
+ ///
+ /// \param Loc The insert and source location description.
+ /// \param AllocaIP IP where the pointer to the temporary scan buffer is allocated.
+ /// \param ScanVars Scan Variables.
+ /// \param ScanVarsType Element types of the scan variables.
+ /// \param IsInclusive Whether it is an inclusive or exclusive scan.
+ ///
+ /// \returns The insertion position *after* the scan.
+ InsertPointOrErrorTy createScan(const LocationDescription &Loc,
+ InsertPointTy AllocaIP,
+ ArrayRef<llvm::Value *> ScanVars,
+ ArrayRef<llvm::Type *> ScanVarsType,
+ bool IsInclusive);
/// Generator for '#omp critical'
///
/// \param Loc The insert and source location description.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 28662efc02882..8fa91f3f9315b 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -59,6 +59,8 @@
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
+#include <cassert>
+#include <cstddef>
#include <cstdint>
#include <optional>
@@ -3981,6 +3983,336 @@ OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
/*Conditional*/ true, /*hasFinalize*/ true);
}
+llvm::CallInst *
+OpenMPIRBuilder::emitNoUnwindRuntimeCall(llvm::FunctionCallee Callee,
+                                         ArrayRef<llvm::Value *> Args,
+                                         const llvm::Twine &Name) {
+  SmallVector<llvm::OperandBundleDef, 1> NoBundles; // call carries no bundles
+  auto *RuntimeCall = Builder.CreateCall(Callee, Args, NoBundles, Name);
+  RuntimeCall->setDoesNotThrow(); // the emitted call is marked non-throwing
+  return RuntimeCall;
+}
+
+// Expects input basic block is dominated by BeforeScanBB.
+// Once Scan directive is encountered, the code after scan directive should be
+// dominated by AfterScanBB. Scan directive splits the code sequence to
+// scan and input phase. Based on whether inclusive or exclusive
+// clause is used in the scan directive and whether input loop or scan loop
+// is lowered, it adds jumps to input and scan phase. First Scan loop is the
+// input loop and second is the scan loop. The code generated handles only
+// inclusive scans now.
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
+ const LocationDescription &Loc, InsertPointTy AllocaIP,
+ ArrayRef<llvm::Value *> ScanVars, ArrayRef<llvm::Type *> ScanVarsType,
+ bool IsInclusive) {
+ if (ScanInfo.OMPFirstScanLoop) {
+ llvm::Error Err =
+ emitScanBasedDirectiveDeclsIR(AllocaIP, ScanVars, ScanVarsType);
+ if (Err) {
+ return Err;
+ }
+ }
+ if (!updateToLocation(Loc))
+ return Loc.IP;
+
+ llvm::Value *IV = ScanInfo.IV;
+
+ if (ScanInfo.OMPFirstScanLoop) {
+ // Emit buffer[i] = red; at the end of the input phase.
+ for (size_t i = 0; i < ScanVars.size(); i++) {
+ Value *BuffPtr = ScanInfo.ScanBuffPtrs[ScanVars[i]];
+ Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr);
+ Type *DestTy = ScanVarsType[i];
+ Value *Val = Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+ Value *Src = Builder.CreateLoad(DestTy, ScanVars[i]);
+
+ Builder.CreateStore(Src, Val);
+ }
+ }
+ Builder.CreateBr(ScanInfo.OMPScanLoopExit);
+ emitBlock(ScanInfo.OMPScanDispatch, Builder.GetInsertBlock()->getParent());
+
+ if (!ScanInfo.OMPFirstScanLoop) {
+ IV = ScanInfo.IV;
+ // Emit red = buffer[i]; at the entrance to the scan phase.
+ // TODO: if exclusive scan, the red = buffer[i-1] needs to be updated.
+ for (size_t i = 0; i < ScanVars.size(); i++) {
+ Value *BuffPtr = ScanInfo.ScanBuffPtrs[ScanVars[i]];
+ Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr);
+ Type *DestTy = ScanVarsType[i];
+ Value *SrcPtr =
+ Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+ Value *Src = Builder.CreateLoad(DestTy, SrcPtr);
+ Builder.CreateStore(Src, ScanVars[i]);
+ }
+ }
+
+ // TODO: Update it to CreateBr and remove dead blocks
+ llvm::Value *CmpI = Builder.getInt1(true);
+ if (ScanInfo.OMPFirstScanLoop == IsInclusive) {
+ Builder.CreateCondBr(CmpI, ScanInfo.OMPBeforeScanBlock,
+ ScanInfo.OMPAfterScanBlock);
+ } else {
+ Builder.CreateCondBr(CmpI, ScanInfo.OMPAfterScanBlock,
+ ScanInfo.OMPBeforeScanBlock);
+ }
+ emitBlock(ScanInfo.OMPAfterScanBlock, Builder.GetInsertBlock()->getParent());
+ Builder.SetInsertPoint(ScanInfo.OMPAfterScanBlock);
+ return Builder.saveIP();
+}
+
+// Creates one buffer-pointer slot per scan variable at AllocaIP, then, in a
+// masked region (filter 0) before ScanInfo.OMPScanInit's terminator, mallocs
+// Span + 1 elements per variable into its slot; a barrier follows the region.
+Error OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
+    InsertPointTy AllocaIP, ArrayRef<Value *> ScanVars,
+    ArrayRef<Type *> ScanVarsType) {
+  assert(ScanVars.size() == ScanVarsType.size() &&
+         "expected one element type per scan variable");
+  Builder.restoreIP(AllocaIP);
+  // Create the shared pointer at alloca IP.
+  for (Value *ScanVar : ScanVars)
+    ScanInfo.ScanBuffPtrs[ScanVar] =
+        Builder.CreateAlloca(Builder.getPtrTy(), nullptr, "vla");
+  // Allocate temporary buffer by master thread
+  auto BodyGenCB = [&](InsertPointTy /*AllocaIP*/,
+                       InsertPointTy CodeGenIP) -> Error {
+    Builder.restoreIP(CodeGenIP);
+    // Use Span's own type for the constant so the add is always well-typed.
+    Value *AllocSpan = Builder.CreateAdd(
+        ScanInfo.Span, ConstantInt::get(ScanInfo.Span->getType(), 1));
+    Type *IntPtrTy = Builder.getInt32Ty();
+    for (size_t i = 0; i < ScanVars.size(); i++) {
+      Constant *Allocsize = ConstantExpr::getSizeOf(ScanVarsType[i]);
+      Allocsize = ConstantExpr::getTruncOrBitCast(Allocsize, IntPtrTy);
+      Value *Buff = Builder.CreateMalloc(IntPtrTy, ScanVarsType[i], Allocsize,
+                                         AllocSpan, nullptr, "arr");
+      Builder.CreateStore(Buff, ScanInfo.ScanBuffPtrs[ScanVars[i]]);
+    }
+    return Error::success();
+  };
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
+  Builder.SetInsertPoint(ScanInfo.OMPScanInit->getTerminator());
+  llvm::Value *FilterVal = Builder.getInt32(0);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+      createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal);
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+  BasicBlock *InputBB = Builder.GetInsertBlock();
+  if (InputBB->getTerminator())
+    Builder.SetInsertPoint(InputBB->getTerminator());
+  AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+  return Error::success();
+}
+
+// In a masked region (filter 0) at ScanInfo.OMPScanFinish, copies element
+// [Span] of each temporary scan buffer into the original reduction variable
+// and frees the buffer; a barrier is emitted after the region.
+Error OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
+    SmallVector<ReductionInfo> ReductionInfos) {
+  auto BodyGenCB = [&](InsertPointTy /*AllocaIP*/,
+                       InsertPointTy CodeGenIP) -> Error {
+    Builder.restoreIP(CodeGenIP);
+    // Bind by const reference: only fields are read, so copying each
+    // ReductionInfo per iteration is unnecessary.
+    for (const ReductionInfo &RedInfo : ReductionInfos) {
+      Value *BuffPtr = ScanInfo.ScanBuffPtrs[RedInfo.PrivateVariable];
+      Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr);
+      Type *SrcTy = RedInfo.ElementType;
+      Value *Val =
+          Builder.CreateInBoundsGEP(SrcTy, Buff, ScanInfo.Span, "arrayOffset");
+      Value *Src = Builder.CreateLoad(SrcTy, Val);
+      Builder.CreateStore(Src, RedInfo.Variable);
+      Builder.CreateFree(Buff);
+    }
+    return Error::success();
+  };
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
+
+  // Insert before the terminator of the finish block if it already has one.
+  if (auto *Term = ScanInfo.OMPScanFinish->getTerminator())
+    Builder.SetInsertPoint(Term);
+  else
+    Builder.SetInsertPoint(ScanInfo.OMPScanFinish);
+  llvm::Value *FilterVal = Builder.getInt32(0);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+      createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal);
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+  BasicBlock *InputBB = Builder.GetInsertBlock();
+  if (InputBB->getTerminator())
+    Builder.SetInsertPoint(InputBB->getTerminator());
+  AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+  if (!AfterIP)
+    return AfterIP.takeError();
+  Builder.restoreIP(*AfterIP);
+  return Error::success();
+}
+
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
+ const LocationDescription &Loc,
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos) {
+
+ if (!updateToLocation(Loc))
+ return Loc.IP;
+ auto BodyGenCB = [&](InsertPointTy AllocaIP,
+ InsertPointTy CodeGenIP) -> Error {
+ Builder.restoreIP(CodeGenIP);
+ Function *CurFn = Builder.GetInsertBlock()->getParent();
+ // for (int k = 0; k != ceil(log2(n)); ++k)
+ llvm::BasicBlock *LoopBB =
+ BasicBlock::Create(CurFn->getContext(), "omp.outer.log.scan.body");
+ llvm::BasicBlock *ExitBB =
+ splitBB(Builder, false, "omp.outer.log.scan.exit");
+ llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
+ Builder.GetInsertBlock()->getModule(),
+ (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy());
+ llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
+ llvm::Value *Arg =
+ Builder.CreateUIToFP(ScanInfo.Span, Builder.getDoubleTy());
+ llvm::Value *LogVal = emitNoUnwindRuntimeCall(F, Arg, "");
+ F = llvm::Intrinsic::getOrInsertDeclaration(
+ Builder.GetInsertBlock()->getModule(),
+ (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy());
+ LogVal = emitNoUnwindRuntimeCall(F, LogVal, "");
+ LogVal = Builder.CreateFPToUI(LogVal, Builder.getInt32Ty());
+ llvm::Value *NMin1 = Builder.CreateNUWSub(
+ ScanInfo.Span, llvm::ConstantInt::get(ScanInfo.Span->getType(), 1));
+ Builder.SetInsertPoint(InputBB);
+ Builder.CreateBr(LoopBB);
+ emitBlock(LoopBB, CurFn);
+ Builder.SetInsertPoint(LoopBB);
+
+ PHINode *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+ //// size pow2k = 1;
+ PHINode *Pow2K = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+ Counter->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 0),
+ InputBB);
+ Pow2K->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 1),
+ InputBB);
+ //// for (size i = n - 1; i >= 2 ^ k; --i)
+ //// tmp[i] op= tmp[i-pow2k];
+ llvm::BasicBlock *InnerLoopBB =
+ BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.body");
+ llvm::BasicBlock *InnerExitBB =
+ BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.exit");
+ llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K);
+ Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+ emitBlock(InnerLoopBB, CurFn);
+ Builder.SetInsertPoint(InnerLoopBB);
+ auto *IVal = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+ IVal->addIncoming(NMin1, LoopBB);
+ for (ReductionInfo RedInfo : ReductionInfos) {
+ Value *ReductionVal = RedInfo.PrivateVariable;
+ Value *BuffPtr = ScanInfo.ScanBuffPtrs[ReductionVal];
+ Value *Buff = Builder.CreateLoad(Builder.getPtrTy(), BuffPtr);
+ Type *DestTy = RedInfo.ElementType;
+ Value *IV = Builder.CreateAdd(IVal, Builder.getInt32(1));
+ Value *LHSPtr =
+ Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+ Value *OffsetIval = Builder.CreateNUWSub(IV, Pow2K);
+ Value *RHSPtr =
+ Builder.CreateInBoundsGEP(DestTy, Buff, OffsetIval, "arrayOffset");
+ Value *LHS = Builder.CreateLoad(DestTy, LHSPtr);
+ Value *RHS = Builder.CreateLoad(DestTy, RHSPtr);
+ llvm::Value *Result;
+ InsertPointOrErrorTy AfterIP =
+ RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
+ if (!AfterIP)
+ return AfterIP.takeError();
+ Builder.CreateStore(Result, LHSPtr);
+ }
+ llvm::Value *NextIVal = Builder.CreateNUWSub(
+ IVal, llvm::ConstantInt::get(Builder.getInt32Ty(), 1));
+ IVal->addIncoming(NextIVal, Builder.GetInsertBlock());
+ CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K);
+ Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+ emitBlock(InnerExitBB, CurFn);
+ llvm::Value *Next = Builder.CreateNUWAdd(
+ Counter, llvm::ConstantInt::get(Counter->getType(), 1));
+ Counter->addIncoming(Next, Builder.GetInsertBlock());
+ // pow2k <<= 1;
+ llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true);
+ Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock());
+ llvm::Value *Cmp = Builder.CreateICmpNE(Next, LogVal);
+ Builder.CreateCondBr(Cmp, LoopBB, ExitBB);
+ Builder.SetInsertPoint(ExitBB->getFirstInsertionPt());
+ return Error::success();
+ };
+
+ // TODO: Perform finalization actions for variables. This has to be
+ // called for variables which have destructors/finalizers.
+ auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
+
+ llvm::Value *FilterVal = Builder.getInt32(0);
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+ createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal);
+
+ if (!AfterIP)
+ return AfterIP.takeError();
+ Builder.restoreIP(*AfterIP);
+ AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+
+ if (!AfterIP)
+ return AfterIP.takeError();
+ Builder.restoreIP(*AfterIP);
+ Error Err = emitScanBasedDirectiveFinalsIR(ReductionInfos);
+ if (Err) {
+ return Err;
+ }
+
+ return AfterIP;
+}
+
+// Drives the two-pass lowering of a scan-based directive: InputLoopGen runs
+// first with ScanInfo.OMPFirstScanLoop set to true, then ScanLoopGen runs with
+// the flag set to false and receives the builder's current insertion point.
+// Returns the first error reported by either callback, or success when both
+// complete without error.
+Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
+    llvm::function_ref<Error()> InputLoopGen,
+    llvm::function_ref<Error(LocationDescription Loc)> ScanLoopGen) {
+  // First pass: emit the loop carrying the input phase, i.e.
+  //   for (i: 0..<num_iters>) {
+  //     <input phase>;
+  //     buffer[i] = red;
+  //   }
+  ScanInfo.OMPFirstScanLoop = true;
+  if (Error InputErr = InputLoopGen())
+    return InputErr;
+
+  // Second pass: emit the loop carrying the scan phase, i.e.
+  //   for (i: 0..<num_iters>) {
+  //     red = buffer[i];
+  //     <scan phase>;
+  //   }
+  // The flag is cleared before the generator is invoked.
+  ScanInfo.OMPFirstScanLoop = false;
+  if (Error ScanErr = ScanLoopGen(Builder.saveIP()))
+    return ScanErr;
+  return Error::success();
+}
+
+void OpenMPIRBuilder::createScanBBs() {
+  // The four scan blocks are created without a parent function here.
+  auto &Ctx = Builder.GetInsertBlock()->getParent()->getContext();
+  auto MakeBlock = [&Ctx](const char *BlockName) {
+    return BasicBlock::Create(Ctx, BlockName);
+  };
+  ScanInfo.OMPScanDispatch = MakeBlock("omp.inscan.dispatch");
+  ScanInfo.OMPAfterScanBlock = MakeBlock("omp.after.scan.bb");
+  ScanInfo.OMPBeforeScanBlock = MakeBlock("omp.before.scan.bb");
+  ScanInfo.OMPScanLoopExit = MakeBlock("omp.scan.loop.exit");
+}
CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
BasicBlock *PostInsertBefore, const Twine &Name) {
@@ -4078,6 +4410,92 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
return CL;
}
+Expected<SmallVector<llvm::CanonicalLoopInfo *>>
+OpenMPIRBuilder::createCanonicalScanLoops(
+ const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+ Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+ InsertPointTy ComputeIP, const Twine &Name) {
+ LocationDescription ComputeLoc =
+ ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
+ updateToLocation(ComputeLoc);
+
+ Value *TripCount = calculateCanonicalLoopTripCount(
+ ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
+ ScanInfo.Span = TripCount;
+ ScanInfo.OMPScanInit = splitBB(Builder, true, "scan.init");
+ Builder.SetInsertPoint(ScanInfo.OMPScanInit);
+
+ auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
+ /// The control of the loopbody of following structure:
+ ///
+ /// InputBlock
+ /// |
+ /// ContinueBlock
+ ///
+ /// is transformed to:
+ ///
+ /// InputBlock
+ /// |
+ /// OMPScanDispatch
+ ///
+ /// OMPBeforeScanBlock
+ /// |
+ /// OMPScanLoopExit
+ /// |
+ /// ContinueBlock
+ ///
+ /// OMPBeforeScanBlock dominates the control flow of code generated until
+ /// scan directive is encountered and OMPAfterScanBlock dominates the
+ /// control flow of code generated after scan is encountered. The successor
+ /// of OMPScanDispatch can be OMPBeforeScanBlock or OMPAfterScanBlock based
+ /// on 1.whether it is in Input phase or Scan Phase , 2. whether it is an
+ /// exclusive or inclusive scan.
+ Builder.restoreIP(CodeGenIP);
+ ScanInfo.IV = IV;
+ createScanBBs();
+ BasicBlock *InputBlock = Builder.GetInsertBlock();
+ Instruction *Terminator = InputBlock->getTerminator();
+ assert(Terminator->getNumSuccessors() == 1);
+ BasicBlock *ContinueBlock = Terminator->getSuccessor(0);
+ Terminator->setSuccessor(0, ScanInfo.OMPScanDispatch);
+ emitBlock(ScanInfo.OMPBeforeScanBlock,
+ Builder.GetInsertBlock()->getParent());
+ Builder.CreateBr(ScanInfo.OMPScanLoopExit);
+ emitBlock(ScanInfo.OMPScanLoopExit, Builder.GetInsertBlock()->getParent());
+ Builder.CreateBr(ContinueBlock);
+ Builder.SetInsertPoint(ScanInfo.OMPBeforeScanBlock->getFirstInsertionPt());
+ return BodyGenCB(Builder.saveIP(), IV);
+ };
+
+ SmallVector<llvm::CanonicalLoopInfo *> Result;
+ const auto &&InputLoopGen = [&]() -> Error {
+ auto LoopInfo =
+ createCanonicalLoop(Builder.saveIP(), BodyGen, Start, Stop, Step,
+ IsSigned, InclusiveStop, ComputeIP, Name, true);
+ if (!LoopInfo)
+ return LoopInfo.takeError();
+ Result.push_back(*LoopInfo);
+ Builder.restoreIP((*LoopInfo)->getAfterIP());
+ return Error::success();
+ };
+ const auto &&ScanLoopGen = [&](LocationDescription Loc) -> Error {
+ auto LoopInfo =
+ createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned,
+ InclusiveStop, ComputeIP, Name, true);
+ if (!LoopInfo)
+ return LoopInfo.takeError();
+ Result.push_back(*LoopInfo);
+ Builder.restoreIP((*LoopInfo)->getAfterIP());
+ ScanInfo.OMPScanFinish = Builder.GetInsertBlock();
+ return Error::success();
+ };
+ Error Err = emitScanBasedDirectiveIR(InputLoopGen, ScanLoopGen);
+ if (Err) {
+ return Err;
+ }
+ return Result;
+}
+
Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
bool IsSigned, bool InclusiveStop, const Twine &Name) {
@@ -4141,7 +4559,7 @@ Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
- InsertPointTy ComputeIP, const Twine &Name) {
+ InsertPointTy ComputeIP, const Twine &Name, bool InScan) {
LocationDescription ComputeLoc =
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
@@ -4152,6 +4570,9 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
Builder.restoreIP(CodeGenIP);
Value *Span = Builder.CreateMul(IV, Step);
Value *IndVar = Builder.CreateAdd(Span, Start);
+ if (InScan) {
+ ScanInfo.IV = IndVar;
+ }
return BodyGenCB(Builder.saveIP(), IndVar);
};
LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 2d3d318be7ff1..e301c389beb71 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -23,6 +23,7 @@
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include <cstdlib>
#include <optional>
using namespace llvm;
@@ -5336,6 +5337,100 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) {
EXPECT_TRUE(findGEPZeroOne(ReductionFn->getArg(1), FirstRHS, SecondRHS));
}
+void createScan(llvm::Value *scanVar, llvm::Type *scanType,
+ OpenMPIRBuilder &OMPBuilder, IRBuilder<> &Builder,
+ OpenMPIRBuilder::LocationDescription Loc,
+ OpenMPIRBuilder::InsertPointTy &allocaIP) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ ASSERT_EXPECTED_INIT(
+ InsertPointTy, retIp,
+ OMPBuilder.createScan(Loc, allocaIP, {scanVar}, {scanType}, true));
+ Builder.restoreIP(retIp);
+}
+
+TEST_F(OpenMPIRBuilderTest, ScanReduction) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ IRBuilder<> Builder(BB);
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+ Value *TripCount = F->getArg(0);
+ Type *LCTy = TripCount->getType();
+ Value *StartVal = ConstantInt::get(LCTy, 1);
+ Value *StopVal = ConstantInt::get(LCTy, 100);
+ Value *Step = ConstantInt::get(LCTy, 1);
+ auto AllocaIP = Builder.saveIP();
+
+ llvm::Value *ScanVar = Builder.CreateAlloca(Builder.getFloatTy());
+ llvm::Value *OrigVar = Builder.CreateAlloca(Builder.getFloatTy());
+ unsigned NumBodiesGenerated = 0;
+ auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
+ NumBodiesGenerated += 1;
+ Builder.restoreIP(CodeGenIP);
+ createScan(ScanVar, Builder.getFloatTy(), OMPBuilder, Builder, Loc,
+ AllocaIP);
+ return Error::success();
+ };
+ SmallVector<CanonicalLoopInfo *> Loops;
+ ASSERT_EXPECTED_INIT(SmallVector<CanonicalLoopInfo *>, loopsVec,
+ OMPBuilder.createCanonicalScanLoops(
+ Loc, LoopBodyGenCB, StartVal, StopVal, Step, false,
+ false, Builder.saveIP(), "scan"));
+ Loops = loopsVec;
+ EXPECT_EQ(Loops.size(), 2U);
+ CanonicalLoopInfo *InputLoop = Loops.front();
+ CanonicalLoopInfo *ScanLoop = Loops.back();
+ Builder.restoreIP(ScanLoop->getAfterIP());
+ InputLoop->assertOK();
+ ScanLoop->assertOK();
+
+ EXPECT_EQ(ScanLoop->getAfter(), Builder.GetInsertBlock());
+ EXPECT_EQ(NumBodiesGenerated, 2U);
+ SmallVector<OpenMPIRBuilder::ReductionInfo> ReductionInfos = {
+ {Builder.getFloatTy(), OrigVar, ScanVar,
+ /*EvaluationKind=*/OpenMPIRBuilder::EvalKind::Scalar, sumReduction,
+ /*ReductionGenClang=*/nullptr, sumAtomicReduction}};
+ OpenMPIRBuilder::LocationDescription RedLoc({InputLoop->getAfterIP(), DL});
+ llvm::BasicBlock *Cont = splitBB(Builder, false, "omp.scan.loop.cont");
+ ASSERT_EXPECTED_INIT(InsertPointTy, retIp,
+ OMPBuilder.emitScanReduction(RedLoc, ReductionInfos));
+ Builder.restoreIP(retIp);
+ Builder.CreateBr(Cont);
+ Builder.SetInsertPoint(Cont);
+ unsigned NumMallocs = 0;
+ unsigned NumFrees = 0;
+ unsigned NumMasked = 0;
+ unsigned NumEndMasked = 0;
+ unsigned NumLog = 0;
+ unsigned NumCeil = 0;
+ for (Instruction &I : instructions(F)) {
+ if (isa<CallInst>(I)) {
+ CallInst *Call = dyn_cast<CallInst>(&I);
+ auto Name = Call->getCalledFunction()->getName();
+ if (Name.equals_insensitive("malloc")) {
+ NumMallocs += 1;
+ } else if (Name.equals_insensitive("free")) {
+ NumFrees += 1;
+ } else if (Name.equals_insensitive("__kmpc_masked")) {
+ NumMasked += 1;
+ } else if (Name.equals_insensitive("__kmpc_end_masked")) {
+ NumEndMasked += 1;
+ } else if (Name.equals_insensitive("llvm.log2.f64")) {
+ NumLog += 1;
+ } else if (Name.equals_insensitive("llvm.ceil.f64")) {
+ NumCeil += 1;
+ }
+ }
+ }
+ EXPECT_EQ(NumBodiesGenerated, 2U);
+ EXPECT_EQ(NumMasked, 3U);
+ EXPECT_EQ(NumEndMasked, 3U);
+ EXPECT_EQ(NumMallocs, 1U);
+ EXPECT_EQ(NumFrees, 1U);
+ EXPECT_EQ(NumLog, 1U);
+ EXPECT_EQ(NumCeil, 1U);
+}
+
TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
More information about the llvm-commits
mailing list