[Mlir-commits] [llvm] [mlir] scan lowering changes (PR #133149)
Anchu Rajendran S
llvmlistbot at llvm.org
Wed Mar 26 13:10:32 PDT 2025
https://github.com/anchuraj created https://github.com/llvm/llvm-project/pull/133149
None
>From af71598bf25890f111dc47676edc381f5389421e Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Tue, 25 Mar 2025 11:21:37 -0500
Subject: [PATCH] scan lowering changes
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 58 ++-
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 338 +++++++++++++++++-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 241 ++++++++++---
3 files changed, 575 insertions(+), 62 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 28909cef4748d..3e5916e18a444 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -30,6 +30,7 @@
namespace llvm {
class CanonicalLoopInfo;
struct TargetRegionEntryInfo;
+class ScanInfo;
class OffloadEntriesInfoManager;
class OpenMPIRBuilder;
@@ -728,6 +729,11 @@ class OpenMPIRBuilder {
LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
const Twine &Name = "loop");
+ /// Generator for the pair of canonical loops used to lower a loop that
+ /// carries an `inscan` reduction. The first returned loop runs the input
+ /// phase and the second runs the scan phase (see emitScanBasedDirectiveIR).
+ /// The remaining parameters have the same meaning as for
+ /// createCanonicalLoop; \p InScan requests that the scan control-flow blocks
+ /// be spliced into each generated loop body.
+ Expected<SmallVector<llvm::CanonicalLoopInfo *>> createCanonicalScanLoops(
+ const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+ Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+ InsertPointTy ComputeIP, const Twine &Name, bool InScan);
+
/// Calculate the trip count of a canonical loop.
///
/// This allows specifying user-defined loop counter values using increment,
@@ -800,10 +806,12 @@ class OpenMPIRBuilder {
///
/// \returns An object representing the created control flow structure which
/// can be used for loop-associated directives.
+ /// \param InScan If true, the computed induction value is recorded in the
+ /// scan-lowering state before \p BodyGenCB is invoked.
- Expected<CanonicalLoopInfo *> createCanonicalLoop(
- const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
- Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
- InsertPointTy ComputeIP = {}, const Twine &Name = "loop");
+ Expected<CanonicalLoopInfo *>
+ createCanonicalLoop(const LocationDescription &Loc,
+ LoopBodyGenCallbackTy BodyGenCB, Value *Start,
+ Value *Stop, Value *Step, bool IsSigned,
+ bool InclusiveStop, InsertPointTy ComputeIP = {},
+ const Twine &Name = "loop", bool InScan = false);
/// Collapse a loop nest into a single loop.
///
@@ -1530,6 +1538,16 @@ class OpenMPIRBuilder {
ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
Function *ReduceFn, AttributeList FuncAttrs);
+ /// Emit a call to \p callee with \p args at the current insertion point and
+ /// mark the call as not throwing.
+ llvm::CallInst *emitNoUnwindRuntimeCall(llvm::FunctionCallee callee,
+ ArrayRef<llvm::Value *> args,
+ const llvm::Twine &name);
+
+ /// Create (but do not insert) the basic blocks used for the scan control
+ /// flow of the loop body currently being generated, and record them in the
+ /// scan-lowering state.
+ void createScanBBs();
+ /// Allocate one scan buffer, sized from \p span, for each of \p ScanVars.
+ void emitScanBasedDirectiveDeclsIR(llvm::Value *span,
+ ArrayRef<llvm::Value *> ScanVars);
+ /// Copy the last element of each reduction's scan buffer back into the
+ /// reduction's original variable.
+ void emitScanBasedDirectiveFinalsIR(
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos);
+
/// This function emits a helper that gathers Reduce lists from the first
/// lane of every active warp to lanes in the first warp.
///
@@ -1981,7 +1999,7 @@ class OpenMPIRBuilder {
/// in reductions.
/// \param ReductionInfos A list of info on each reduction variable.
/// \param IsNoWait A flag set if the reduction is marked as nowait.
/// \param IsByRef A flag set if the reduction is using reference
/// or direct value.
InsertPointOrErrorTy createReductions(const LocationDescription &Loc,
InsertPointTy AllocaIP,
@@ -2177,7 +2195,6 @@ class OpenMPIRBuilder {
// block, if possible, or else at the end of the function. Also add a branch
// from current block to BB if current block does not have a terminator.
void emitBlock(BasicBlock *BB, Function *CurFn, bool IsFinished = false);
-
/// Emits code for OpenMP 'if' clause using specified \a BodyGenCallbackTy
/// Here is the logic:
/// if (Cond) {
@@ -2602,7 +2619,18 @@ class OpenMPIRBuilder {
InsertPointOrErrorTy createMasked(const LocationDescription &Loc,
BodyGenCallbackTy BodyGenCB,
FinalizeCallbackTy FiniCB, Value *Filter);
-
+ /// Emit the log-step combination of the scan buffers that runs between the
+ /// two scan loops, followed by a barrier. The copy-out of the final values
+ /// is emitted at \p FinalizeIP, which is updated to the point after it.
+ InsertPointOrErrorTy emitScanReduction(
+ const LocationDescription &Loc, InsertPointTy &FinalizeIP,
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos);
+
+ /// Generate the input-phase loop with \p FirstGen, then the scan-phase loop
+ /// with \p SecondGen at the first loop's exit, and return both loops.
+ Expected<SmallVector<llvm::CanonicalLoopInfo *>> emitScanBasedDirectiveIR(
+ llvm::function_ref<Expected<CanonicalLoopInfo *>()> FirstGen,
+ llvm::function_ref<Expected<CanonicalLoopInfo *>(LocationDescription loc)>
+ SecondGen);
+ /// Lower an `omp scan` directive inside a loop body prepared by
+ /// createCanonicalScanLoops. \p Name is currently unused.
+ InsertPointOrErrorTy createScan(const LocationDescription &Loc,
+ InsertPointTy AllocaIP,
+ ArrayRef<llvm::Value *> ScanVars,
+ const Twine &Name, bool IsInclusive);
/// Generator for '#omp critical'
///
/// \param Loc The insert and source location description.
@@ -3711,6 +3739,22 @@ class CanonicalLoopInfo {
void invalidate();
};
+/// Bookkeeping shared by the OpenMPIRBuilder scan-lowering helpers
+/// (createCanonicalScanLoops, createScan, emitScanReduction and the
+/// emitScanBasedDirective* functions).
+class ScanInfo {
+public:
+  /// Blocks created by OpenMPIRBuilder::createScanBBs for the loop body that
+  /// is currently being generated.
+  llvm::BasicBlock *OMPBeforeScanBlock = nullptr;
+  llvm::BasicBlock *OMPAfterScanBlock = nullptr;
+  llvm::BasicBlock *OMPScanExitBlock = nullptr;
+  llvm::BasicBlock *OMPScanDispatch = nullptr;
+  llvm::BasicBlock *OMPScanLoopExit = nullptr;
+  /// True while the first (input-phase) loop is generated, false while the
+  /// second (scan-phase) loop is generated.
+  bool OMPFirstScanLoop = false;
+  /// Maps each scan/reduction variable to its per-iteration buffer.
+  llvm::SmallDenseMap<llvm::Value *, llvm::Value *> ReductionVarToScanBuffs;
+  // NOTE(review): not referenced by any of the lowering code shown here.
+  SmallVector<llvm::Value *> privateReductionVariables;
+  SmallVector<llvm::Value *> originalReductionVariables;
+  /// Induction value of the loop body currently being generated. Initialized
+  /// so it is never read as an indeterminate pointer.
+  llvm::Value *iv = nullptr;
+  /// Stop - Start of the scan loop, computed by createCanonicalScanLoops.
+  llvm::Value *span = nullptr;
+  /// Original successors of loop-body blocks that were redirected to the scan
+  /// dispatch block.
+  SmallVector<llvm::BasicBlock *> continueBlocks;
+};
+
} // end namespace llvm
#endif // LLVM_FRONTEND_OPENMP_OMPIRBUILDER_H
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 2e5ce5308eea5..c74727585d4d9 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -67,6 +67,8 @@
using namespace llvm;
using namespace omp;
+// Scan-lowering state used by the OpenMPIRBuilder scan helpers below.
+// NOTE(review): this object is shared by every OpenMPIRBuilder instance and
+// is mutated without synchronization; it should become per-builder state.
+// `static` at least keeps the symbol internal to this file so it cannot clash
+// with another global named `scanInfo`.
+static ScanInfo scanInfo;
+
static cl::opt<bool>
OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
cl::desc("Use optimistic attributes describing "
@@ -80,6 +82,7 @@ static cl::opt<double> UnrollThresholdFactor(
cl::init(1.5));
#ifndef NDEBUG
+
/// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
/// at position IP1 may change the meaning of IP2 or vice-versa. This is because
/// an InsertPoint stores the instruction before something is inserted. For
@@ -3918,6 +3921,272 @@ OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
/*Conditional*/ true, /*hasFinalize*/ true);
}
+/// Emit a call to \p callee with \p args (and no operand bundles) at the
+/// current insertion point, mark it as not throwing, and return it.
+llvm::CallInst *
+OpenMPIRBuilder::emitNoUnwindRuntimeCall(llvm::FunctionCallee callee,
+                                         ArrayRef<llvm::Value *> args,
+                                         const llvm::Twine &name) {
+  SmallVector<llvm::OperandBundleDef, 1> NoBundles;
+  llvm::CallInst *Result = Builder.CreateCall(callee, args, NoBundles, name);
+  Result->setDoesNotThrow();
+  return Result;
+}
+
+/// Lower an `omp scan` directive inside a loop body prepared by
+/// createCanonicalScanLoops.
+///
+/// In the first (input-phase) loop this allocates the scan buffers at
+/// \p AllocaIP and stores each scan variable into buffer[iv]. In the second
+/// (scan-phase) loop it reloads each scan variable from buffer[iv] in the
+/// dispatch block. It then branches into the before-scan or after-scan half
+/// of the body, depending on the phase and on \p IsInclusive.
+/// \p Name is currently unused.
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
+    const LocationDescription &Loc, InsertPointTy AllocaIP,
+    ArrayRef<llvm::Value *> ScanVars, const Twine &Name, bool IsInclusive) {
+  // The buffers are allocated once, while lowering the input-phase loop.
+  if (scanInfo.OMPFirstScanLoop) {
+    Builder.restoreIP(AllocaIP);
+    emitScanBasedDirectiveDeclsIR(scanInfo.span, ScanVars);
+  }
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  llvm::Value *iv = scanInfo.iv;
+  // NOTE(review): the element type is hard-coded to i32 (the intended type is
+  // presumably that of the scan variable); non-i32 scan variables are not
+  // handled yet.
+  llvm::Type *destTy = Builder.getInt32Ty();
+
+  if (scanInfo.OMPFirstScanLoop) {
+    // Emit buffer[i] = red; at the end of the input phase.
+    for (llvm::Value *ScanVar : ScanVars) {
+      llvm::Value *buff = scanInfo.ReductionVarToScanBuffs[ScanVar];
+      llvm::Value *val =
+          Builder.CreateInBoundsGEP(destTy, buff, iv, "arrayOffset");
+      llvm::Value *src = Builder.CreateLoad(destTy, ScanVar);
+      llvm::Value *dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          val, destTy->getPointerTo(defaultAS));
+      Builder.CreateStore(src, dest);
+    }
+  }
+  Builder.CreateBr(scanInfo.OMPScanLoopExit);
+  llvm::LLVMContext &llvmContext = Builder.getContext();
+  Builder.SetInsertPoint(scanInfo.OMPScanDispatch);
+
+  // NOTE(review): zero is only the identity of some reduction operators
+  // (e.g. +); confirm this reset is correct for the supported reductions.
+  ConstantInt *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
+  for (llvm::Value *ScanVar : ScanVars) {
+    llvm::Value *dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        ScanVar, destTy->getPointerTo(defaultAS));
+    Builder.CreateStore(Zero, dest);
+  }
+
+  if (!scanInfo.OMPFirstScanLoop) {
+    // Emit red = buffer[i]; at the entrance to the scan phase.
+    for (llvm::Value *ScanVar : ScanVars) {
+      llvm::Value *buff = scanInfo.ReductionVarToScanBuffs[ScanVar];
+      llvm::Value *newVPtr =
+          Builder.CreateInBoundsGEP(destTy, buff, iv, "arrayOffset");
+      llvm::Value *newV = Builder.CreateLoad(destTy, newVPtr);
+      llvm::Value *dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          ScanVar, destTy->getPointerTo(defaultAS));
+      Builder.CreateStore(newV, dest);
+    }
+  }
+  // NOTE(review): placeholder condition. 100 >= 0 is always true, so the
+  // branch always takes its first successor. The successor order is what
+  // selects which half of the body runs in this loop.
+  llvm::Value *testCondVal1 =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/100);
+  llvm::Value *testCondVal2 =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
+  llvm::Value *CmpI = Builder.CreateICmpUGE(testCondVal1, testCondVal2);
+  if (scanInfo.OMPFirstScanLoop == IsInclusive) {
+    Builder.CreateCondBr(CmpI, scanInfo.OMPBeforeScanBlock,
+                         scanInfo.OMPAfterScanBlock);
+  } else {
+    Builder.CreateCondBr(CmpI, scanInfo.OMPAfterScanBlock,
+                         scanInfo.OMPBeforeScanBlock);
+  }
+  emitBlock(scanInfo.OMPAfterScanBlock, Builder.GetInsertBlock()->getParent());
+  Builder.SetInsertPoint(scanInfo.OMPAfterScanBlock);
+  return Builder.saveIP();
+}
+
+/// Allocate one scan buffer per entry of \p ScanVars at the current insertion
+/// point and record it in scanInfo.ReductionVarToScanBuffs.
+///
+/// The buffers are accessed up to and including index span + 1:
+/// emitScanReduction combines elements downward from index span + 1, and
+/// emitScanBasedDirectiveFinalsIR reads the result from that index. Each
+/// buffer therefore holds span + 2 elements. With only span + 1 elements, the
+/// last access was one past the end.
+void OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
+    llvm::Value *span, ArrayRef<llvm::Value *> ScanVars) {
+  // Build the constant with span's own type so an i64 span does not produce
+  // an add with mismatched operand types.
+  llvm::Value *allocSpan =
+      Builder.CreateAdd(span, ConstantInt::get(span->getType(), 2));
+  for (llvm::Value *scanVar : ScanVars) {
+    // NOTE(review): the element type is hard-coded to i32, while
+    // emitScanReduction and emitScanBasedDirectiveFinalsIR index the buffer
+    // with the reduction's ElementType. createScan also indexes the buffer by
+    // the induction *value*, which only stays in range for loops starting at
+    // 0 or 1. Confirm both before supporting general loops.
+    llvm::Value *buff =
+        Builder.CreateAlloca(Builder.getInt32Ty(), allocSpan, "vla");
+    scanInfo.ReductionVarToScanBuffs[scanVar] = buff;
+  }
+}
+
+/// For every reduction in \p reductionInfos, load the last element of its scan
+/// buffer (index span + 1, where emitScanReduction leaves the combined value)
+/// and store it into the reduction's original variable.
+void OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos) {
+  llvm::Value *span = scanInfo.span;
+  llvm::Value *OMPLast = Builder.CreateNSWAdd(
+      span, llvm::ConstantInt::get(span->getType(), 1, /*isSigned=*/false));
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  for (const ReductionInfo &RI : reductionInfos) {
+    // Buffers are keyed by the private copy of the reduction variable.
+    llvm::Value *buff = scanInfo.ReductionVarToScanBuffs[RI.PrivateVariable];
+    llvm::Type *srcTy = RI.ElementType;
+    llvm::Value *val =
+        Builder.CreateInBoundsGEP(srcTy, buff, OMPLast, "arrayOffset");
+    llvm::Value *src = Builder.CreateLoad(srcTy, val);
+    llvm::Value *dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        RI.Variable, srcTy->getPointerTo(defaultAS));
+    Builder.CreateStore(src, dest);
+  }
+}
+
+/// Emit the log-step combination of the scan buffers between the two scan
+/// loops:
+///
+///   for (k = 0; k < ceil(log2(n)); ++k, pow2k <<= 1)
+///     for (i = n - 1; i >= pow2k; --i)
+///       tmp[i + 1] = tmp[i + 1] op tmp[i + 1 - pow2k];
+///
+/// where n = span + 1. A barrier follows. The copy-out of the final values
+/// (emitScanBasedDirectiveFinalsIR) is emitted at \p FinalizeIP, which is
+/// updated to the point after it.
+///
+/// \returns the insertion point after the barrier, or the error produced by a
+/// reduction generator or by the barrier.
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
+    const LocationDescription &Loc, InsertPointTy &FinalizeIP,
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos) {
+  llvm::Value *spanDiff = scanInfo.span;
+
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  Function *CurFn = Builder.GetInsertBlock()->getParent();
+  // All element-index arithmetic uses the span's type, so loops with a
+  // non-i32 induction variable do not mix integer widths. For an i32 span
+  // this emits exactly the same IR as before.
+  llvm::Type *IdxTy = spanDiff->getType();
+  llvm::BasicBlock *LoopBB =
+      BasicBlock::Create(CurFn->getContext(), "omp.outer.log.scan.body");
+  llvm::BasicBlock *ExitBB =
+      BasicBlock::Create(CurFn->getContext(), "omp.outer.log.scan.exit");
+  llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
+      Builder.GetInsertBlock()->getModule(),
+      (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy());
+  llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
+  // n = span + 1 elements take part in the scan.
+  llvm::Value *span =
+      Builder.CreateAdd(spanDiff, llvm::ConstantInt::get(IdxTy, 1));
+  llvm::Value *Arg = Builder.CreateUIToFP(span, Builder.getDoubleTy());
+  llvm::Value *LogVal = emitNoUnwindRuntimeCall(F, Arg, "");
+  F = llvm::Intrinsic::getOrInsertDeclaration(
+      Builder.GetInsertBlock()->getModule(),
+      (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy());
+  LogVal = emitNoUnwindRuntimeCall(F, LogVal, "");
+  LogVal = Builder.CreateFPToUI(LogVal, Builder.getInt32Ty());
+  llvm::Value *NMin1 =
+      Builder.CreateNUWSub(span, llvm::ConstantInt::get(IdxTy, 1));
+  Builder.SetInsertPoint(InputBB);
+  Builder.CreateBr(LoopBB);
+  emitBlock(LoopBB, CurFn);
+  Builder.SetInsertPoint(LoopBB);
+
+  // Outer loop: Counter = k, Pow2K = 2^k.
+  auto *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+  auto *Pow2K = Builder.CreatePHI(IdxTy, 2);
+  Counter->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 0),
+                       InputBB);
+  Pow2K->addIncoming(llvm::ConstantInt::get(IdxTy, 1), InputBB);
+  // Inner loop: for (i = n - 1; i >= 2^k; --i) tmp[i] op= tmp[i - pow2k];
+  llvm::BasicBlock *InnerLoopBB =
+      BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.body");
+  llvm::BasicBlock *InnerExitBB =
+      BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.exit");
+  llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K);
+  Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+  emitBlock(InnerLoopBB, CurFn);
+  Builder.SetInsertPoint(InnerLoopBB);
+  auto *IVal = Builder.CreatePHI(IdxTy, 2);
+  IVal->addIncoming(NMin1, LoopBB);
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  for (const ReductionInfo &RI : reductionInfos) {
+    llvm::Value *buff = scanInfo.ReductionVarToScanBuffs[RI.PrivateVariable];
+    llvm::Type *destTy = RI.ElementType;
+    // The buffers are addressed with a +1 offset (indices 1 .. n).
+    llvm::Value *IV =
+        Builder.CreateAdd(IVal, llvm::ConstantInt::get(IdxTy, 1));
+    llvm::Value *lhsPtr =
+        Builder.CreateInBoundsGEP(destTy, buff, IV, "arrayOffset");
+    llvm::Value *offsetIval = Builder.CreateNUWSub(IV, Pow2K);
+    llvm::Value *rhsPtr =
+        Builder.CreateInBoundsGEP(destTy, buff, offsetIval, "arrayOffset");
+    llvm::Value *lhs = Builder.CreateLoad(destTy, lhsPtr);
+    llvm::Value *rhs = Builder.CreateLoad(destTy, rhsPtr);
+    llvm::Value *lhsAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        lhsPtr, rhs->getType()->getPointerTo(defaultAS));
+    llvm::Value *result;
+    InsertPointOrErrorTy GenIP =
+        RI.ReductionGen(Builder.saveIP(), lhs, rhs, result);
+    if (!GenIP)
+      return GenIP.takeError();
+    // Continue wherever the reduction generator left off, so the store and
+    // the latch are emitted after the code it produced.
+    Builder.restoreIP(*GenIP);
+    Builder.CreateStore(result, lhsAddr);
+  }
+  llvm::Value *NextIVal =
+      Builder.CreateNUWSub(IVal, llvm::ConstantInt::get(IdxTy, 1));
+  IVal->addIncoming(NextIVal, Builder.GetInsertBlock());
+  CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K);
+  Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+  emitBlock(InnerExitBB, CurFn);
+  llvm::Value *Next = Builder.CreateNUWAdd(
+      Counter, llvm::ConstantInt::get(Counter->getType(), 1));
+  Counter->addIncoming(Next, Builder.GetInsertBlock());
+  // pow2k <<= 1;
+  llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true);
+  Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock());
+  // Use an unsigned "<" rather than "!=". When n == 1, ceil(log2(n)) == 0 and
+  // the outer body has already run once, so "!=" would keep looping until
+  // Counter wrapped, with Pow2K eventually shifting to 0.
+  llvm::Value *Cmp = Builder.CreateICmpULT(Next, LogVal);
+  Builder.CreateCondBr(Cmp, LoopBB, ExitBB);
+  emitBlock(ExitBB, CurFn);
+  Builder.SetInsertPoint(ExitBB);
+  InsertPointOrErrorTy BarrierIP =
+      createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+  if (!BarrierIP)
+    return BarrierIP.takeError();
+
+  Builder.restoreIP(FinalizeIP);
+  emitScanBasedDirectiveFinalsIR(reductionInfos);
+  FinalizeIP = Builder.saveIP();
+
+  return BarrierIP;
+}
+
+/// Generate the input-phase loop with \p FirstGen and then, at its exit, the
+/// scan-phase loop with \p SecondGen. scanInfo.OMPFirstScanLoop tells the
+/// body generators which phase is being emitted.
+///
+/// \returns both loops (input phase first), or the first error produced by
+/// either generator.
+Expected<SmallVector<llvm::CanonicalLoopInfo *>>
+OpenMPIRBuilder::emitScanBasedDirectiveIR(
+    llvm::function_ref<Expected<CanonicalLoopInfo *>()> FirstGen,
+    llvm::function_ref<Expected<CanonicalLoopInfo *>(LocationDescription loc)>
+        SecondGen) {
+  SmallVector<llvm::CanonicalLoopInfo *> ret;
+  {
+    // Emit loop with input phase:
+    // #pragma omp ...
+    // for (i: 0..<num_iters>) {
+    //   <input phase>;
+    //   buffer[i] = red;
+    // }
+    scanInfo.OMPFirstScanLoop = true;
+    Expected<CanonicalLoopInfo *> result = FirstGen();
+    // Test the Expected itself. Calling takeError() twice would consume the
+    // error and then try to build an Expected from Error::success().
+    if (!result)
+      return result.takeError();
+    Builder.restoreIP((*result)->getAfterIP());
+    ret.push_back(*result);
+  }
+  {
+    // Emit loop with scan phase:
+    // for (i: 0..<num_iters>) {
+    //   red = buffer[i];
+    //   <scan phase>;
+    // }
+    scanInfo.OMPFirstScanLoop = false;
+    Expected<CanonicalLoopInfo *> result = SecondGen(Builder.saveIP());
+    if (!result)
+      return result.takeError();
+    Builder.restoreIP((*result)->getAfterIP());
+    ret.push_back(*result);
+  }
+  return ret;
+}
+
+/// Create fresh, not-yet-inserted basic blocks for the scan control flow of
+/// the loop body currently being generated, and store them in scanInfo.
+/// Callers place them into the function with emitBlock.
+void OpenMPIRBuilder::createScanBBs() {
+  LLVMContext &Ctx = Builder.GetInsertBlock()->getParent()->getContext();
+  auto NewBB = [&Ctx](const char *BBName) {
+    return BasicBlock::Create(Ctx, BBName);
+  };
+  scanInfo.OMPScanExitBlock = NewBB("omp.exit.inscan.bb");
+  scanInfo.OMPScanDispatch = NewBB("omp.inscan.dispatch");
+  scanInfo.OMPAfterScanBlock = NewBB("omp.after.scan.bb");
+  scanInfo.OMPBeforeScanBlock = NewBB("omp.before.scan.bb");
+  scanInfo.OMPScanLoopExit = NewBB("omp.scan.loop.exit");
+}
+
CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
BasicBlock *PostInsertBefore, const Twine &Name) {
@@ -4015,10 +4284,72 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
return CL;
}
+/// Generate the two canonical loops used to lower a loop with an `inscan`
+/// reduction. The loop span (Stop - Start, with bounds swapped for a negative
+/// signed step) is computed at \p ComputeIP (or \p Loc) and published in
+/// scanInfo.span for the scan buffers. The loops themselves are produced by
+/// emitScanBasedDirectiveIR: the input-phase loop first, then the scan-phase
+/// loop.
+Expected<SmallVector<llvm::CanonicalLoopInfo *>>
+OpenMPIRBuilder::createCanonicalScanLoops(
+ const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+ Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+ InsertPointTy ComputeIP, const Twine &Name, bool InScan) {
+ auto *IndVarTy = cast<IntegerType>(Start->getType());
+ assert(IndVarTy == Stop->getType() && "Stop type mismatch");
+ assert(IndVarTy == Step->getType() && "Step type mismatch");
+ LocationDescription ComputeLoc =
+ ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
+ updateToLocation(ComputeLoc);
+
+ ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
+
+ // Distance between Start and Stop. It is only non-negative when the loop
+ // executes at least once.
+ // NOTE(review): this is the value distance, not divided by Step and not
+ // adjusted for InclusiveStop. emitScanReduction treats span + 1 as the
+ // iteration count, which only matches unit-step inclusive loops; confirm.
+ Value *Span;
+
+ // NOTE(review): no check is emitted here for a loop that executes zero
+ // iterations (e.g. because UB < LB); confirm it is handled elsewhere.
+
+ if (IsSigned) {
+ // Ensure that increment is positive. If not, negate and invert LB and UB.
+ Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
+ Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
+ Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
+ Span = Builder.CreateSub(UB, LB, "", false, true);
+ } else {
+ Span = Builder.CreateSub(Stop, Start, "", true);
+ }
+ // Body wrapper shared by both loops. When InScan is set, it splices the
+ // scan blocks into the body before user code is generated:
+ // - successor 0 of the current block's terminator (presumably the loop
+ //   latch) is remembered in continueBlocks and redirected to
+ //   OMPScanDispatch;
+ // - OMPBeforeScanBlock branches to OMPScanLoopExit, which branches to that
+ //   remembered successor;
+ // - the user body is then generated at the start of OMPBeforeScanBlock.
+ auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
+ if (InScan) {
+ scanInfo.iv = IV;
+ createScanBBs();
+ auto terminator = Builder.GetInsertBlock()->getTerminator();
+ scanInfo.continueBlocks.push_back(terminator->getSuccessor(0));
+ terminator->setSuccessor(0, scanInfo.OMPScanDispatch);
+ emitBlock(scanInfo.OMPBeforeScanBlock,
+ Builder.GetInsertBlock()->getParent());
+ Builder.CreateBr(scanInfo.OMPScanLoopExit);
+
+ emitBlock(scanInfo.OMPScanLoopExit,
+ Builder.GetInsertBlock()->getParent());
+ Builder.CreateBr(scanInfo.continueBlocks.back());
+ emitBlock(scanInfo.OMPScanDispatch,
+ Builder.GetInsertBlock()->getParent());
+ Builder.SetInsertPoint(
+ scanInfo.OMPBeforeScanBlock->getFirstInsertionPt());
+ }
+ return BodyGenCB(Builder.saveIP(), IV);
+ };
+ // NOTE(review): both generators pass InScan = true to createCanonicalLoop
+ // regardless of this function's InScan argument; confirm that is intended.
+ const auto &&FirstGen = [&]() {
+ return createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned,
+ InclusiveStop, ComputeIP, Name, true);
+ };
+ // The scan-phase loop is emitted at the location handed over by
+ // emitScanBasedDirectiveIR (the input-phase loop's exit).
+ const auto &&SecondGen = [&](LocationDescription loc) {
+ return createCanonicalLoop(loc, BodyGen, Start, Stop, Step, IsSigned,
+ InclusiveStop, ComputeIP, Name, true);
+ };
+ scanInfo.span = Span;
+ auto result = emitScanBasedDirectiveIR(FirstGen, SecondGen);
+ return result;
+}
+
Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
bool IsSigned, bool InclusiveStop, const Twine &Name) {
-
// Consider the following difficulties (assuming 8-bit signed integers):
// * Adding \p Step to the loop counter which passes \p Stop may overflow:
// DO I = 1, 100, 50
@@ -4078,7 +4409,7 @@ Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
- InsertPointTy ComputeIP, const Twine &Name) {
+ InsertPointTy ComputeIP, const Twine &Name, bool InScan) {
LocationDescription ComputeLoc =
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
@@ -4089,6 +4420,9 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
Builder.restoreIP(CodeGenIP);
Value *Span = Builder.CreateMul(IV, Step);
Value *IndVar = Builder.CreateAdd(Span, Start);
+ if (InScan) {
+ scanInfo.iv = IndVar;
+ }
return BodyGenCB(Builder.saveIP(), IndVar);
};
LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..d5aabf8b75ae1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -86,7 +86,9 @@ class OpenMPLoopInfoStackFrame
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
- llvm::CanonicalLoopInfo *loopInfo = nullptr;
+ // Canonical loops recorded while translating the loop nest of this frame.
+ // Normally this is a single (collapsed) loop. Lowering a worksharing loop
+ // with an `inscan` reduction records two loops: the input phase followed by
+ // the scan phase.
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
};
/// Custom error class to signal translation errors that don't need reporting,
@@ -232,8 +234,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
op.getReductionSyms())
result = todo("reduction");
if (op.getReductionMod() &&
- op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
- result = todo("reduction with modifier");
+ op.getReductionMod().value() == omp::ReductionModifier::task)
+ result = todo("reduction with task modifier");
};
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -383,15 +385,15 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
-/// Find the loop information structure for the loop nest being translated. It
-/// will return a `null` value unless called from the translation function for
-/// a loop wrapper operation after successfully translating its body.
+/// Find the loop information structures for the loop nest being translated.
+/// The result is empty unless called from the translation function for a loop
+/// wrapper operation after successfully translating its body. A loop nest
+/// lowered for an `inscan` reduction contributes two loops (input phase, then
+/// scan phase); otherwise a single collapsed loop is recorded.
-static llvm::CanonicalLoopInfo *
-findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
- llvm::CanonicalLoopInfo *loopInfo = nullptr;
+static SmallVector<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- loopInfo = frame.loopInfo;
+ loopInfos = frame.loopInfos;
return WalkResult::interrupt();
});
- return loopInfo;
+ return loopInfos;
}
/// Converts the given region that appears within an OpenMP dialect operation to
@@ -2258,26 +2260,61 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(regionBlock, opInst)))
return failure();
- builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
-
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
- ompBuilder->applyWorkshareLoop(
- ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
- convertToScheduleKind(schedule), chunk, isSimd,
- scheduleMod == omp::ScheduleModifier::monotonic,
- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType);
-
- if (failed(handleError(wsloopIP, opInst)))
- return failure();
-
- // Process the reductions if required.
- if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
- allocaIP, reductionDecls,
- privateReductionVariables, isByRef)))
- return failure();
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+ findCurrentLoopInfos(moduleTranslation);
+ auto beforeLoopFinishIp = loopInfos.front()->getAfterIP();
+ auto afterLoopFinishIp = loopInfos.back()->getAfterIP();
+ bool isInScanRegion =
+ wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+ mlir::omp::ReductionModifier::inscan);
+ if (isInScanRegion) {
+ builder.restoreIP(beforeLoopFinishIp);
+ SmallVector<OwningReductionGen> owningReductionGens;
+ SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+ collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
+ owningReductionGens, owningAtomicReductionGens,
+ privateReductionVariables, reductionInfos);
+ llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
+ ompBuilder->emitScanReduction(builder.saveIP(), afterLoopFinishIp,
+ reductionInfos);
+ if (failed(handleError(redIP, opInst)))
+ return failure();
+ builder.restoreIP(*redIP);
+ builder.CreateBr(cont);
+ }
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+ ompBuilder->applyWorkshareLoop(
+ ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+ convertToScheduleKind(schedule), chunk, isSimd,
+ scheduleMod == omp::ScheduleModifier::monotonic,
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+ workshareLoopType);
+
+ if (failed(handleError(wsloopIP, opInst)))
+ return failure();
+ }
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+ if (isInScanRegion) {
+ SmallVector<Region *> reductionRegions;
+ llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
+ [](omp::DeclareReductionOp reductionDecl) {
+ return &reductionDecl.getCleanupRegion();
+ });
+ if (failed(inlineOmpRegionCleanup(
+ reductionRegions, privateReductionVariables, moduleTranslation,
+ builder, "omp.reduction.cleanup")))
+ return failure();
+ } else {
+ // Process the reductions if required.
+ if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
+ allocaIP, reductionDecls,
+ privateReductionVariables, isByRef)))
+ return failure();
+ }
return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
privateVarsInfo.llvmVars,
privateVarsInfo.privatizers);
@@ -2467,6 +2504,59 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
llvm_unreachable("Unknown ClauseOrderKind kind");
}
+/// Converts an OpenMP scan operation into LLVM IR using OpenMPIRBuilder.
+///
+/// The scan list items are taken from the inclusive clause when present,
+/// otherwise from the exclusive clause, and are handed to
+/// OpenMPIRBuilder::createScan.
+static LogicalResult
+convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder,
+               LLVM::ModuleTranslation &moduleTranslation) {
+  if (failed(checkImplementationStatus(opInst)))
+    return failure();
+  auto scanOp = cast<omp::ScanOp>(opInst);
+  bool isInclusive = scanOp.hasInclusiveVars();
+  mlir::OperandRange mlirScanVars =
+      isInclusive ? scanOp.getInclusiveVars() : scanOp.getExclusiveVars();
+  SmallVector<llvm::Value *> llvmScanVars;
+  llvmScanVars.reserve(mlirScanVars.size());
+  for (mlir::Value val : mlirScanVars)
+    llvmScanVars.push_back(moduleTranslation.lookupValue(val));
+
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+      findAllocaInsertPoint(builder, moduleTranslation);
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+      moduleTranslation.getOpenMPBuilder()->createScan(
+          ompLoc, allocaIP, llvmScanVars, "scan", isInclusive);
+  if (failed(handleError(afterIP, opInst)))
+    return failure();
+
+  builder.restoreIP(*afterIP);
+
+  // TODO: The argument of the LoopNestOp is stored into the index variable,
+  // and that variable is used on both sides of the scan operation. That makes
+  // the MLIR invalid per the OpenMP restriction: "Intra-iteration dependences
+  // from a statement in the structured block sequence that precede a scan
+  // directive to a statement in the structured block sequence that follows a
+  // scan directive must not exist, except for dependences for the list items
+  // specified in an inclusive or exclusive clause." The argument of the
+  // LoopNestOp needs to be loaded again after the ScanOp so the generated
+  // MLIR is valid. Until then, re-emit that store here.
+  //
+  // The parent is not necessarily a LoopNestOp. Use dyn_cast: cast<> would
+  // assert instead of taking the early return.
+  auto loopOp = dyn_cast<omp::LoopNestOp>(scanOp->getParentOp());
+  if (!loopOp)
+    return success();
+  auto storeOp =
+      dyn_cast<LLVM::StoreOp>(&scanOp->getParentRegion()->front().front());
+  if (!storeOp)
+    return success();
+  llvm::Value *src = moduleTranslation.lookupValue(storeOp->getOperand(0));
+  if (src == moduleTranslation.lookupValue(loopOp.getRegion().getArgument(0))) {
+    llvm::Value *dest = moduleTranslation.lookupValue(storeOp->getOperand(1));
+    builder.CreateStore(src, dest);
+  }
+  return success();
+}
+
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2540,12 +2630,15 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
return failure();
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
- ompBuilder->applySimd(loopInfo, alignedVars,
- simdOp.getIfExpr()
- ? moduleTranslation.lookupValue(simdOp.getIfExpr())
- : nullptr,
- order, simdlen, safelen);
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+ findCurrentLoopInfos(moduleTranslation);
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+ ompBuilder->applySimd(
+ loopInfo, alignedVars,
+ simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+ : nullptr,
+ order, simdlen, safelen);
+ }
return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
privateVarsInfo.llvmVars,
@@ -2612,16 +2705,52 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
ompLoc.DL);
computeIP = loopInfos.front()->getPreheaderIP();
}
+ if (auto wsloopOp = loopOp->getParentOfType<omp::WsloopOp>()) {
+ bool isInScanRegion =
+ wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+ mlir::omp::ReductionModifier::inscan);
+ // TODOSCAN: Take care of loop and add asserts if required
+ if (isInScanRegion) {
+ llvm::Expected<SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
+ ompBuilder->createCanonicalScanLoops(
+ loc, bodyGen, lowerBound, upperBound, step,
+ /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, "loop",
+ isInScanRegion);
+
+ if (failed(handleError(loopResults, *loopOp)))
+ return failure();
+ auto beforeLoop = loopResults->front();
+ auto afterLoop = loopResults->back();
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](OpenMPLoopInfoStackFrame &frame) {
+ frame.loopInfos.push_back(beforeLoop);
+ frame.loopInfos.push_back(afterLoop);
+ return WalkResult::interrupt();
+ });
+ builder.restoreIP(afterLoop->getAfterIP());
+ return success();
+ } else {
+ llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
+ ompBuilder->createCanonicalLoop(
+ loc, bodyGen, lowerBound, upperBound, step,
+ /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
- llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
- ompBuilder->createCanonicalLoop(
- loc, bodyGen, lowerBound, upperBound, step,
- /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
+ if (failed(handleError(loopResult, *loopOp)))
+ return failure();
- if (failed(handleError(loopResult, *loopOp)))
- return failure();
+ loopInfos.push_back(*loopResult);
+ }
+ } else {
+ llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
+ ompBuilder->createCanonicalLoop(
+ loc, bodyGen, lowerBound, upperBound, step,
+ /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
+
+ if (failed(handleError(loopResult, *loopOp)))
+ return failure();
- loopInfos.push_back(*loopResult);
+ loopInfos.push_back(*loopResult);
+ }
}
// Collapse loops. Store the insertion point because LoopInfos may get
@@ -2633,7 +2762,8 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
// after applying transformations.
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
+ frame.loopInfos.push_back(
+ ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}));
return WalkResult::interrupt();
});
@@ -4212,18 +4342,20 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
bool loopNeedsBarrier = false;
llvm::Value *chunk = nullptr;
- llvm::CanonicalLoopInfo *loopInfo =
- findCurrentLoopInfo(moduleTranslation);
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
- ompBuilder->applyWorkshareLoop(
- ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
- convertToScheduleKind(schedule), chunk, isSimd,
- scheduleMod == omp::ScheduleModifier::monotonic,
- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType);
-
- if (!wsloopIP)
- return wsloopIP.takeError();
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+ findCurrentLoopInfos(moduleTranslation);
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+ ompBuilder->applyWorkshareLoop(
+ ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+ convertToScheduleKind(schedule), chunk, isSimd,
+ scheduleMod == omp::ScheduleModifier::monotonic,
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+ workshareLoopType);
+
+ if (!wsloopIP)
+ return wsloopIP.takeError();
+ }
}
if (failed(cleanupPrivateVars(builder, moduleTranslation,
@@ -5202,6 +5334,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
.Case([&](omp::WsloopOp) {
return convertOmpWsloop(*op, builder, moduleTranslation);
})
+ .Case([&](omp::ScanOp) {
+ return convertOmpScan(*op, builder, moduleTranslation);
+ })
.Case([&](omp::SimdOp) {
return convertOmpSimd(*op, builder, moduleTranslation);
})
More information about the Mlir-commits
mailing list