[Mlir-commits] [llvm] [mlir] scan lowering changes (PR #133149)
Anchu Rajendran S
llvmlistbot at llvm.org
Wed Mar 26 13:39:14 PDT 2025
https://github.com/anchuraj updated https://github.com/llvm/llvm-project/pull/133149
>From eceb4c85b7b020981cdbac353bd3e5f06aefeeec Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Tue, 25 Mar 2025 11:21:37 -0500
Subject: [PATCH] scan lowering changes
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 74 +++-
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 338 +++++++++++++++-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 242 +++++++++---
.../Target/LLVMIR/openmp-reduction-scan-out | 364 ++++++++++++++++++
.../Target/LLVMIR/openmp-reduction-scan.mlir | 66 ++++
5 files changed, 1023 insertions(+), 61 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/openmp-reduction-scan-out
create mode 100644 mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 28909cef4748d..c1cce4434ceb4 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -30,6 +30,7 @@
namespace llvm {
class CanonicalLoopInfo;
struct TargetRegionEntryInfo;
+class ScanInfo;
class OffloadEntriesInfoManager;
class OpenMPIRBuilder;
@@ -728,6 +729,11 @@ class OpenMPIRBuilder {
LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
const Twine &Name = "loop");
+  /// Like createCanonicalLoop, but emits two loops (input and scan phase).
+  Expected<SmallVector<llvm::CanonicalLoopInfo *>> createCanonicalScanLoops(
+      const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+      Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+      InsertPointTy ComputeIP, const Twine &Name, bool InScan);
/// Calculate the trip count of a canonical loop.
///
/// This allows specifying user-defined loop counter values using increment,
@@ -800,10 +806,12 @@ class OpenMPIRBuilder {
///
/// \returns An object representing the created control flow structure which
/// can be used for loop-associated directives.
- Expected<CanonicalLoopInfo *> createCanonicalLoop(
- const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
- Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
- InsertPointTy ComputeIP = {}, const Twine &Name = "loop");
+  /// \param InScan If true, records the induction variable for scan codegen.
+  Expected<CanonicalLoopInfo *> createCanonicalLoop(
+      const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+      Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+      InsertPointTy ComputeIP = {}, const Twine &Name = "loop",
+      bool InScan = false);
/// Collapse a loop nest into a single loop.
///
@@ -1530,6 +1538,20 @@ class OpenMPIRBuilder {
ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
Function *ReduceFn, AttributeList FuncAttrs);
+  /// Emits the loop from \p FirstGen, then the loop from \p SecondGen.
+  Expected<SmallVector<llvm::CanonicalLoopInfo *>> emitScanBasedDirectiveIR(
+      llvm::function_ref<Expected<CanonicalLoopInfo *>()> FirstGen,
+      llvm::function_ref<Expected<CanonicalLoopInfo *>(LocationDescription loc)>
+          SecondGen);
+  /// Helpers: emit a no-throw call; create scan blocks, buffers, write-back.
+  llvm::CallInst *emitNoUnwindRuntimeCall(llvm::FunctionCallee callee,
+                                          ArrayRef<llvm::Value *> args,
+                                          const llvm::Twine &name);
+  void createScanBBs();
+  void emitScanBasedDirectiveDeclsIR(llvm::Value *span,
+                                     ArrayRef<llvm::Value *> ScanVars);
+  void emitScanBasedDirectiveFinalsIR(
+      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos);
/// This function emits a helper that gathers Reduce lists from the first
/// lane of every active warp to lanes in the first warp.
///
@@ -1981,7 +2003,7 @@ class OpenMPIRBuilder {
/// in reductions.
/// \param ReductionInfos A list of info on each reduction variable.
/// \param IsNoWait A flag set if the reduction is marked as nowait.
- /// \param IsByRef A flag set if the reduction is using reference
+  /// \param IsByRef A flag set if the reduction is using reference
/// or direct value.
InsertPointOrErrorTy createReductions(const LocationDescription &Loc,
InsertPointTy AllocaIP,
@@ -2177,7 +2199,6 @@ class OpenMPIRBuilder {
// block, if possible, or else at the end of the function. Also add a branch
// from current block to BB if current block does not have a terminator.
void emitBlock(BasicBlock *BB, Function *CurFn, bool IsFinished = false);
-
/// Emits code for OpenMP 'if' clause using specified \a BodyGenCallbackTy
/// Here is the logic:
/// if (Cond) {
@@ -2603,6 +2624,31 @@ class OpenMPIRBuilder {
BodyGenCallbackTy BodyGenCB,
FinalizeCallbackTy FiniCB, Value *Filter);
+  /// Emits the log-step prefix combination over the scan buffers.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param FinalizeIP In: where the final values are copied back to the
+  ///                   original variables. Out: the IP after that copy.
+  /// \param reductionInfos The reductions whose scan buffers are combined.
+  ///
+  /// \returns The insertion position after the barrier emitted at the end.
+  InsertPointOrErrorTy emitScanReduction(
+      const LocationDescription &Loc, InsertPointTy &FinalizeIP,
+      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos);
+
+  /// Generator for '#omp scan' inside a loop with an `inscan` reduction.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param AllocaIP The IP where the temporary buffer for scan reduction
+  ///                 needs to be allocated.
+  /// \param ScanVars The variables listed in the inclusive/exclusive clause.
+  /// \param IsInclusive Indicates if it is an inclusive or exclusive scan.
+  ///
+  /// \returns An insertion point in the block following the scan directive.
+  InsertPointOrErrorTy createScan(const LocationDescription &Loc,
+                                  InsertPointTy AllocaIP,
+                                  ArrayRef<llvm::Value *> ScanVars,
+                                  bool IsInclusive);
/// Generator for '#omp critical'
///
/// \param Loc The insert and source location description.
@@ -3711,6 +3757,22 @@ class CanonicalLoopInfo {
void invalidate();
};
+/// State shared by the scan-loop and '#omp scan' codegen in OpenMPIRBuilder.
+class ScanInfo {
+public:
+  llvm::BasicBlock *OMPBeforeScanBlock = nullptr;
+  llvm::BasicBlock *OMPAfterScanBlock = nullptr;
+  llvm::BasicBlock *OMPScanExitBlock = nullptr;
+  llvm::BasicBlock *OMPScanDispatch = nullptr;
+  llvm::BasicBlock *OMPScanLoopExit = nullptr;
+  bool OMPFirstScanLoop = false; // True while emitting the input-phase loop.
+  llvm::SmallDenseMap<llvm::Value *, llvm::Value *> ReductionVarToScanBuffs;
+  SmallVector<llvm::Value *> privateReductionVariables;
+  SmallVector<llvm::Value *> originalReductionVariables;
+  llvm::Value *iv = nullptr;   // Induction variable of the current scan loop.
+  llvm::Value *span = nullptr; // Distance between the loop's Start and Stop.
+  SmallVector<llvm::BasicBlock *> continueBlocks;
+};
} // end namespace llvm
#endif // LLVM_FRONTEND_OPENMP_OMPIRBUILDER_H
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 2e5ce5308eea5..05821d641efe2 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -67,6 +67,8 @@
using namespace llvm;
using namespace omp;
+// FIXME: mutable file-level state; should become a member of OpenMPIRBuilder.
+static ScanInfo scanInfo;
static cl::opt<bool>
OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
cl::desc("Use optimistic attributes describing "
@@ -80,6 +82,7 @@ static cl::opt<double> UnrollThresholdFactor(
cl::init(1.5));
#ifndef NDEBUG
+
/// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
/// at position IP1 may change the meaning of IP2 or vice-versa. This is because
/// an InsertPoint stores the instruction before something is inserted. For
@@ -3918,6 +3921,272 @@ OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
/*Conditional*/ true, /*hasFinalize*/ true);
}
+/// Emits a call to \p callee with \p args and marks it as not throwing.
+llvm::CallInst *OpenMPIRBuilder::emitNoUnwindRuntimeCall(
+    llvm::FunctionCallee callee, ArrayRef<llvm::Value *> args,
+    const llvm::Twine &name) {
+  llvm::CallInst *RuntimeCall =
+      Builder.CreateCall(callee, args, /*OpBundles=*/{}, name);
+  RuntimeCall->setDoesNotThrow();
+  return RuntimeCall;
+}
+
+/// Lowers '#omp scan'. In the input-phase loop the scan variables are stored
+/// to their buffers at index scanInfo.iv; in the scan-phase loop they are
+/// reloaded from the buffers. The dispatch block then branches to the half of
+/// the loop body that must run first for this loop and \p IsInclusive.
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
+    const LocationDescription &Loc, InsertPointTy AllocaIP,
+    ArrayRef<llvm::Value *> ScanVars, bool IsInclusive) {
+  if (scanInfo.OMPFirstScanLoop) {
+    Builder.restoreIP(AllocaIP);
+    emitScanBasedDirectiveDeclsIR(scanInfo.span, ScanVars);
+  }
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  llvm::Value *iv = scanInfo.iv;
+  // FIXME: the element type is hard-coded to i32; ScanVars only carries
+  // pointers, so the real type has to be passed in by the caller.
+  llvm::Type *destTy = Builder.getInt32Ty();
+  llvm::Type *destPtrTy = Builder.getPtrTy(defaultAS);
+
+  // Returns the buffer emitScanBasedDirectiveDeclsIR allocated for ScanVar.
+  auto getScanBuffer = [&](llvm::Value *ScanVar) {
+    llvm::Value *Buff = scanInfo.ReductionVarToScanBuffs.lookup(ScanVar);
+    assert(Buff && "no scan buffer was allocated for this variable");
+    return Buff;
+  };
+
+  if (scanInfo.OMPFirstScanLoop) {
+    // Emit buffer[i] = red; at the end of the input phase.
+    for (llvm::Value *ScanVar : ScanVars) {
+      llvm::Value *Val = Builder.CreateInBoundsGEP(
+          destTy, getScanBuffer(ScanVar), iv, "arrayOffset");
+      llvm::Value *Src = Builder.CreateLoad(destTy, ScanVar);
+      llvm::Value *Dest =
+          Builder.CreatePointerBitCastOrAddrSpaceCast(Val, destPtrTy);
+      Builder.CreateStore(Src, Dest);
+    }
+  }
+  Builder.CreateBr(scanInfo.OMPScanLoopExit);
+  Builder.SetInsertPoint(scanInfo.OMPScanDispatch);
+
+  // FIXME: 0 is only the identity value of additive reductions.
+  ConstantInt *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
+  for (llvm::Value *ScanVar : ScanVars)
+    Builder.CreateStore(
+        Zero, Builder.CreatePointerBitCastOrAddrSpaceCast(ScanVar, destPtrTy));
+
+  if (!scanInfo.OMPFirstScanLoop) {
+    // Emit red = buffer[i]; at the entrance to the scan phase.
+    for (llvm::Value *ScanVar : ScanVars) {
+      llvm::Value *NewVPtr = Builder.CreateInBoundsGEP(
+          destTy, getScanBuffer(ScanVar), iv, "arrayOffset");
+      llvm::Value *NewV = Builder.CreateLoad(destTy, NewVPtr);
+      llvm::Value *Dest =
+          Builder.CreatePointerBitCastOrAddrSpaceCast(ScanVar, destPtrTy);
+      Builder.CreateStore(NewV, Dest);
+    }
+  }
+  // The dispatch condition is currently always true (placeholder).
+  llvm::Value *CmpI = Builder.getTrue();
+  if (scanInfo.OMPFirstScanLoop == IsInclusive)
+    Builder.CreateCondBr(CmpI, scanInfo.OMPBeforeScanBlock,
+                         scanInfo.OMPAfterScanBlock);
+  else
+    Builder.CreateCondBr(CmpI, scanInfo.OMPAfterScanBlock,
+                         scanInfo.OMPBeforeScanBlock);
+  emitBlock(scanInfo.OMPAfterScanBlock, Builder.GetInsertBlock()->getParent());
+  return Builder.saveIP();
+}
+
+/// Allocates one i32 buffer of \p span + 2 elements per scan variable and
+/// records it in scanInfo.ReductionVarToScanBuffs. Index span + 1 is read by
+/// the scan reduction and the final copy-out, so span + 1 would be too small.
+void OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
+    llvm::Value *span, ArrayRef<llvm::Value *> ScanVars) {
+  // Build the constant in span's type so i64 loop bounds also work.
+  llvm::Value *allocSpan =
+      Builder.CreateAdd(span, ConstantInt::get(span->getType(), 2));
+  for (llvm::Value *scanVar : ScanVars) {
+    // FIXME: the element type is hard-coded to i32.
+    scanInfo.ReductionVarToScanBuffs[scanVar] =
+        Builder.CreateAlloca(Builder.getInt32Ty(), allocSpan, "vla");
+  }
+}
+
+/// Copies the last element of each reduction's scan buffer back into the
+/// original reduction variable (ReductionInfo::Variable).
+/// NOTE(review): this assumes the input phase stored elements at indices
+/// 1..span + 1, i.e. a 1-based induction variable - confirm for other bounds.
+void OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos) {
+  llvm::Value *span = scanInfo.span;
+  // Index of the last prefix value written by emitScanReduction.
+  llvm::Value *OMPLast = Builder.CreateNSWAdd(
+      span, llvm::ConstantInt::get(span->getType(), 1, /*isSigned=*/false));
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  for (const llvm::OpenMPIRBuilder::ReductionInfo &RI : reductionInfos) {
+    llvm::Value *buff =
+        scanInfo.ReductionVarToScanBuffs.lookup(RI.PrivateVariable);
+    assert(buff && "no scan buffer was allocated for this reduction");
+    // Buffers are indexed with the reduction's element type.
+    llvm::Type *srcTy = RI.ElementType;
+    llvm::Value *val =
+        Builder.CreateInBoundsGEP(srcTy, buff, OMPLast, "arrayOffset");
+    llvm::Value *src = Builder.CreateLoad(srcTy, val);
+    llvm::Value *dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        RI.Variable, Builder.getPtrTy(defaultAS));
+    Builder.CreateStore(src, dest);
+  }
+}
+
+/// Combines each scan buffer in place with ceil(log2(n)) passes of
+///   for (i = n - 1; i >= pow2k; --i) buf[i + 1] op= buf[i + 1 - pow2k];
+/// (pow2k = 1, 2, 4, ...; n = span + 1), then emits a barrier and copies the
+/// final values out at \p FinalizeIP, which is updated to the IP after that.
+/// NOTE(review): no masked/single guard is emitted around the passes here.
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
+    const LocationDescription &Loc, InsertPointTy &FinalizeIP,
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos) {
+  llvm::Value *spanDiff = scanInfo.span;
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  // Counters use the span's type, so i64 loops are as well-typed as i32 ones.
+  llvm::Type *IdxTy = spanDiff->getType();
+  ConstantInt *One = ConstantInt::get(IdxTy, 1);
+  llvm::LLVMContext &Ctx = Builder.getContext();
+  llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
+  llvm::Function *CurFn = InputBB->getParent();
+  llvm::BasicBlock *LoopBB = BasicBlock::Create(Ctx, "omp.outer.log.scan.body");
+  llvm::BasicBlock *ExitBB = BasicBlock::Create(Ctx, "omp.outer.log.scan.exit");
+  llvm::BasicBlock *InnerLoopBB =
+      BasicBlock::Create(Ctx, "omp.inner.log.scan.body");
+  llvm::BasicBlock *InnerExitBB =
+      BasicBlock::Create(Ctx, "omp.inner.log.scan.exit");
+  // Number of passes: LogVal = ceil(log2(n)), where n = span + 1.
+  llvm::Value *N = Builder.CreateAdd(spanDiff, One);
+  llvm::Value *Arg = Builder.CreateUIToFP(N, Builder.getDoubleTy());
+  llvm::Function *Log2Fn = llvm::Intrinsic::getOrInsertDeclaration(
+      CurFn->getParent(), llvm::Intrinsic::log2, Builder.getDoubleTy());
+  llvm::Value *LogVal = emitNoUnwindRuntimeCall(Log2Fn, Arg, "");
+  llvm::Function *CeilFn = llvm::Intrinsic::getOrInsertDeclaration(
+      CurFn->getParent(), llvm::Intrinsic::ceil, Builder.getDoubleTy());
+  LogVal = emitNoUnwindRuntimeCall(CeilFn, LogVal, "");
+  LogVal = Builder.CreateFPToUI(LogVal, IdxTy);
+  llvm::Value *NMin1 = Builder.CreateNUWSub(N, One);
+  Builder.SetInsertPoint(InputBB);
+  Builder.CreateBr(LoopBB);
+
+  // Outer loop header: Counter (k) counts passes, Pow2K starts at 1.
+  emitBlock(LoopBB, CurFn);
+  PHINode *Counter = Builder.CreatePHI(IdxTy, 2);
+  PHINode *Pow2K = Builder.CreatePHI(IdxTy, 2);
+  Counter->addIncoming(ConstantInt::get(IdxTy, 0), InputBB);
+  Pow2K->addIncoming(One, InputBB);
+  llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K);
+  Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+  // Inner loop: for (i = n - 1; i >= pow2k; --i)
+  //               buf[i + 1] op= buf[i + 1 - pow2k];
+  emitBlock(InnerLoopBB, CurFn);
+  PHINode *IVal = Builder.CreatePHI(IdxTy, 2);
+  IVal->addIncoming(NMin1, LoopBB);
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  for (const llvm::OpenMPIRBuilder::ReductionInfo &RI : reductionInfos) {
+    llvm::Value *Buff =
+        scanInfo.ReductionVarToScanBuffs.lookup(RI.PrivateVariable);
+    assert(Buff && "no scan buffer was allocated for this reduction");
+    llvm::Type *DestTy = RI.ElementType;
+    llvm::Value *IV = Builder.CreateAdd(IVal, One);
+    llvm::Value *LhsPtr =
+        Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+    llvm::Value *OffsetIVal = Builder.CreateNUWSub(IV, Pow2K);
+    llvm::Value *RhsPtr =
+        Builder.CreateInBoundsGEP(DestTy, Buff, OffsetIVal, "arrayOffset");
+    llvm::Value *Lhs = Builder.CreateLoad(DestTy, LhsPtr);
+    llvm::Value *Rhs = Builder.CreateLoad(DestTy, RhsPtr);
+    llvm::Value *LhsAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        LhsPtr, Builder.getPtrTy(defaultAS));
+    llvm::Value *Result = nullptr;
+    InsertPointOrErrorTy AfterIP =
+        RI.ReductionGen(Builder.saveIP(), Lhs, Rhs, Result);
+    if (!AfterIP)
+      return AfterIP.takeError();
+    // Continue after whatever code the reduction generator emitted.
+    Builder.restoreIP(*AfterIP);
+    Builder.CreateStore(Result, LhsAddr);
+  }
+  llvm::Value *NextIVal = Builder.CreateNUWSub(IVal, One);
+  IVal->addIncoming(NextIVal, Builder.GetInsertBlock());
+  CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K);
+  Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+
+  // Outer loop latch: ++k, pow2k <<= 1.
+  emitBlock(InnerExitBB, CurFn);
+  llvm::Value *Next = Builder.CreateNUWAdd(Counter, One);
+  Counter->addIncoming(Next, Builder.GetInsertBlock());
+  llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true);
+  Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock());
+  // ULT, not NE: for one element LogVal is 0 and NE would loop until wrap.
+  Builder.CreateCondBr(Builder.CreateICmpULT(Next, LogVal), LoopBB, ExitBB);
+  emitBlock(ExitBB, CurFn);
+  InsertPointOrErrorTy BarrierIP =
+      createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+  if (!BarrierIP)
+    return BarrierIP.takeError();
+  Builder.restoreIP(FinalizeIP);
+  emitScanBasedDirectiveFinalsIR(reductionInfos);
+  FinalizeIP = Builder.saveIP();
+  return *BarrierIP;
+}
+
+/// Emits the two loops of a scan-based directive and returns them in order.
+/// \p FirstGen emits the input-phase loop (scanInfo.OMPFirstScanLoop set);
+/// \p SecondGen emits the scan-phase loop at the first loop's after-IP.
+Expected<SmallVector<llvm::CanonicalLoopInfo *>>
+OpenMPIRBuilder::emitScanBasedDirectiveIR(
+    llvm::function_ref<Expected<CanonicalLoopInfo *>()> FirstGen,
+    llvm::function_ref<Expected<CanonicalLoopInfo *>(LocationDescription loc)>
+        SecondGen) {
+  SmallVector<llvm::CanonicalLoopInfo *> ret;
+  // Emit loop with input phase:
+  //   #pragma omp ...
+  //   for (i: 0..<num_iters>) {
+  //     <input phase>;
+  //     buffer[i] = red;
+  //   }
+  scanInfo.OMPFirstScanLoop = true;
+  Expected<CanonicalLoopInfo *> First = FirstGen();
+  // Check the Expected itself: takeError() moves the error out only once.
+  if (!First)
+    return First.takeError();
+  Builder.restoreIP((*First)->getAfterIP());
+  ret.push_back(*First);
+
+  scanInfo.OMPFirstScanLoop = false;
+  Expected<CanonicalLoopInfo *> Second = SecondGen(Builder.saveIP());
+  if (!Second)
+    return Second.takeError();
+  Builder.restoreIP((*Second)->getAfterIP());
+  ret.push_back(*Second);
+  return ret;
+}
+
+/// Creates the (still detached) basic blocks that createCanonicalScanLoops
+/// and createScan use to split the loop body around the scan directive.
+void OpenMPIRBuilder::createScanBBs() {
+  LLVMContext &Ctx = Builder.GetInsertBlock()->getParent()->getContext();
+  auto MakeBB = [&Ctx](const Twine &BBName) {
+    return BasicBlock::Create(Ctx, BBName);
+  };
+  scanInfo.OMPScanExitBlock = MakeBB("omp.exit.inscan.bb");
+  scanInfo.OMPScanDispatch = MakeBB("omp.inscan.dispatch");
+  scanInfo.OMPAfterScanBlock = MakeBB("omp.after.scan.bb");
+  scanInfo.OMPBeforeScanBlock = MakeBB("omp.before.scan.bb");
+  scanInfo.OMPScanLoopExit = MakeBB("omp.scan.loop.exit");
+}
+
CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
BasicBlock *PostInsertBefore, const Twine &Name) {
@@ -4015,10 +4284,72 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
return CL;
}
+/// Emits the input-phase and scan-phase loops for a loop with an `inscan`
+/// reduction; both use the bounds createCanonicalLoop would use. When
+/// \p InScan is set, each body is split around a dispatch block so that
+/// createScan can order the code before and after the scan directive.
+Expected<SmallVector<llvm::CanonicalLoopInfo *>>
+OpenMPIRBuilder::createCanonicalScanLoops(
+    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+    InsertPointTy ComputeIP, const Twine &Name, bool InScan) {
+  auto *IndVarTy = cast<IntegerType>(Start->getType());
+  assert(IndVarTy == Stop->getType() && "Stop type mismatch");
+  assert(IndVarTy == Step->getType() && "Step type mismatch");
+  LocationDescription ComputeLoc =
+      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
+  updateToLocation(ComputeLoc);
+
+  // Distance between Start and Stop; always positive. Sizes the scan buffers.
+  Value *Span;
+  if (IsSigned) {
+    // Ensure that increment is positive. If not, negate and invert LB and UB.
+    ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
+    Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
+    Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
+    Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
+    Span = Builder.CreateSub(UB, LB, "", false, true);
+  } else {
+    Span = Builder.CreateSub(Stop, Start, "", true);
+  }
+
+  // Wraps the user body: when InScan, first split the body around the scan
+  // dispatch block. BodyGenCB is only invoked while this function runs.
+  auto BodyGen = [this, BodyGenCB, InScan](InsertPointTy CodeGenIP,
+                                           Value *IV) {
+    Builder.restoreIP(CodeGenIP);
+    if (InScan) {
+      scanInfo.iv = IV;
+      createScanBBs();
+      Instruction *Terminator = Builder.GetInsertBlock()->getTerminator();
+      scanInfo.continueBlocks.push_back(Terminator->getSuccessor(0));
+      Terminator->setSuccessor(0, scanInfo.OMPScanDispatch);
+      Function *CurFn = Builder.GetInsertBlock()->getParent();
+      emitBlock(scanInfo.OMPBeforeScanBlock, CurFn);
+      Builder.CreateBr(scanInfo.OMPScanLoopExit);
+      emitBlock(scanInfo.OMPScanLoopExit, CurFn);
+      Builder.CreateBr(scanInfo.continueBlocks.back());
+      emitBlock(scanInfo.OMPScanDispatch, CurFn);
+      Builder.SetInsertPoint(
+          scanInfo.OMPBeforeScanBlock->getFirstInsertionPt());
+    }
+    return BodyGenCB(Builder.saveIP(), IV);
+  };
+  const auto FirstGen = [&]() {
+    return createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned,
+                               InclusiveStop, ComputeIP, Name, InScan);
+  };
+  const auto SecondGen = [&](LocationDescription SecondLoc) {
+    return createCanonicalLoop(SecondLoc, BodyGen, Start, Stop, Step, IsSigned,
+                               InclusiveStop, ComputeIP, Name, InScan);
+  };
+  scanInfo.span = Span;
+  return emitScanBasedDirectiveIR(FirstGen, SecondGen);
+}
+
Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
bool IsSigned, bool InclusiveStop, const Twine &Name) {
-
// Consider the following difficulties (assuming 8-bit signed integers):
// * Adding \p Step to the loop counter which passes \p Stop may overflow:
// DO I = 1, 100, 50
@@ -4078,7 +4409,7 @@ Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
- InsertPointTy ComputeIP, const Twine &Name) {
+ InsertPointTy ComputeIP, const Twine &Name, bool InScan) {
LocationDescription ComputeLoc =
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
@@ -4089,6 +4420,9 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
Builder.restoreIP(CodeGenIP);
Value *Span = Builder.CreateMul(IV, Step);
Value *IndVar = Builder.CreateAdd(Span, Start);
+ if (InScan) {
+ scanInfo.iv = IndVar;
+ }
return BodyGenCB(Builder.saveIP(), IndVar);
};
LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..47503b7846769 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -86,7 +86,9 @@ class OpenMPLoopInfoStackFrame
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
- llvm::CanonicalLoopInfo *loopInfo = nullptr;
+ // For constructs like scan, one Loop info frame can contain multiple
+ // Canonical Loops
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
};
/// Custom error class to signal translation errors that don't need reporting,
@@ -232,8 +234,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
op.getReductionSyms())
result = todo("reduction");
if (op.getReductionMod() &&
- op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
- result = todo("reduction with modifier");
+ op.getReductionMod().value() == omp::ReductionModifier::task)
+ result = todo("reduction with task modifier");
};
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -383,15 +385,15 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
-/// Find the loop information structure for the loop nest being translated. It
-/// will return a `null` value unless called from the translation function for
-/// a loop wrapper operation after successfully translating its body.
+/// Find the loop information structures for the loop nest being translated.
+/// It will return an empty vector unless called from the translation function
+/// for a loop wrapper operation after successfully translating its body.
-static llvm::CanonicalLoopInfo *
-findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
- llvm::CanonicalLoopInfo *loopInfo = nullptr;
+static SmallVector<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- loopInfo = frame.loopInfo;
+ loopInfos = frame.loopInfos;
return WalkResult::interrupt();
});
- return loopInfo;
+ return loopInfos;
}
/// Converts the given region that appears within an OpenMP dialect operation to
@@ -2258,26 +2260,61 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(regionBlock, opInst)))
return failure();
- builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
-
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
- ompBuilder->applyWorkshareLoop(
- ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
- convertToScheduleKind(schedule), chunk, isSimd,
- scheduleMod == omp::ScheduleModifier::monotonic,
- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType);
-
- if (failed(handleError(wsloopIP, opInst)))
- return failure();
-
- // Process the reductions if required.
- if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
- allocaIP, reductionDecls,
- privateReductionVariables, isByRef)))
- return failure();
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+ findCurrentLoopInfos(moduleTranslation);
+ auto beforeLoopFinishIp = loopInfos.front()->getAfterIP();
+ auto afterLoopFinishIp = loopInfos.back()->getAfterIP();
+ bool isInScanRegion =
+ wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+ mlir::omp::ReductionModifier::inscan);
+ if (isInScanRegion) {
+ builder.restoreIP(beforeLoopFinishIp);
+ SmallVector<OwningReductionGen> owningReductionGens;
+ SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+ collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
+ owningReductionGens, owningAtomicReductionGens,
+ privateReductionVariables, reductionInfos);
+ llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
+ ompBuilder->emitScanReduction(builder.saveIP(), afterLoopFinishIp,
+ reductionInfos);
+ if (failed(handleError(redIP, opInst)))
+ return failure();
+ builder.restoreIP(*redIP);
+ builder.CreateBr(cont);
+ }
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+ ompBuilder->applyWorkshareLoop(
+ ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+ convertToScheduleKind(schedule), chunk, isSimd,
+ scheduleMod == omp::ScheduleModifier::monotonic,
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+ workshareLoopType);
+
+ if (failed(handleError(wsloopIP, opInst)))
+ return failure();
+ }
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+ if (isInScanRegion) {
+ SmallVector<Region *> reductionRegions;
+ llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
+ [](omp::DeclareReductionOp reductionDecl) {
+ return &reductionDecl.getCleanupRegion();
+ });
+ if (failed(inlineOmpRegionCleanup(
+ reductionRegions, privateReductionVariables, moduleTranslation,
+ builder, "omp.reduction.cleanup")))
+ return failure();
+ } else {
+ // Process the reductions if required.
+ if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
+ allocaIP, reductionDecls,
+ privateReductionVariables, isByRef)))
+ return failure();
+ }
return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
privateVarsInfo.llvmVars,
privateVarsInfo.privatizers);
@@ -2467,6 +2504,60 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
llvm_unreachable("Unknown ClauseOrderKind kind");
}
+/// Converts an OpenMP scan operation into LLVM IR using OpenMPIRBuilder.
+///
+/// The variables of the inclusive clause (or, if there is none, of the
+/// exclusive clause) are handed to OpenMPIRBuilder::createScan, which relies
+/// on the scan state set up while translating the enclosing loop nest.
+static LogicalResult
+convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder,
+               LLVM::ModuleTranslation &moduleTranslation) {
+  if (failed(checkImplementationStatus(opInst)))
+    return failure();
+  auto scanOp = cast<omp::ScanOp>(opInst);
+  bool isInclusive = scanOp.hasInclusiveVars();
+  SmallVector<llvm::Value *> llvmScanVars = moduleTranslation.lookupValues(
+      isInclusive ? scanOp.getInclusiveVars() : scanOp.getExclusiveVars());
+
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+      findAllocaInsertPoint(builder, moduleTranslation);
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+      moduleTranslation.getOpenMPBuilder()->createScan(
+          ompLoc, allocaIP, llvmScanVars, isInclusive);
+  if (failed(handleError(afterIP, opInst)))
+    return failure();
+  builder.restoreIP(*afterIP);
+
+  // TODO: The argument of LoopNestOp is stored into the index variable and
+  // this variable is used across the scan operation. However, that makes the
+  // MLIR invalid ("Intra-iteration dependences from a statement in the
+  // structured block sequence that precede a scan directive to a statement in
+  // the structured block sequence that follows a scan directive must not
+  // exist, except for dependences for the list items specified in an
+  // inclusive or exclusive clause."). The argument of LoopNestOp needs to be
+  // loaded again after the scan op so the generated MLIR is valid; until
+  // then, the store of the LoopNestOp argument is re-emitted here.
+  // Use dyn_cast: cast<> never returns null, so a null check after it is dead.
+  auto loopOp = dyn_cast<omp::LoopNestOp>(scanOp->getParentOp());
+  if (!loopOp)
+    return success();
+
+  // Only handle the pattern where the first op of the region is
+  // `llvm.store %loop_arg0, %indexVar`.
+  Block &firstBlock = scanOp->getParentRegion()->front();
+  auto storeOp = dyn_cast<LLVM::StoreOp>(firstBlock.front());
+  if (!storeOp)
+    return success();
+  llvm::Value *src = moduleTranslation.lookupValue(storeOp.getValue());
+  llvm::Value *iv =
+      moduleTranslation.lookupValue(loopOp.getRegion().getArgument(0));
+  if (src == iv)
+    builder.CreateStore(src,
+                        moduleTranslation.lookupValue(storeOp.getAddr()));
+  return success();
+}
+
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2540,12 +2631,15 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
return failure();
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
- ompBuilder->applySimd(loopInfo, alignedVars,
- simdOp.getIfExpr()
- ? moduleTranslation.lookupValue(simdOp.getIfExpr())
- : nullptr,
- order, simdlen, safelen);
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+ findCurrentLoopInfos(moduleTranslation);
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+ ompBuilder->applySimd(
+ loopInfo, alignedVars,
+ simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+ : nullptr,
+ order, simdlen, safelen);
+ }
return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
privateVarsInfo.llvmVars,
@@ -2612,16 +2706,52 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
ompLoc.DL);
computeIP = loopInfos.front()->getPreheaderIP();
}
+ if (auto wsloopOp = loopOp->getParentOfType<omp::WsloopOp>()) {
+ bool isInScanRegion =
+ wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+ mlir::omp::ReductionModifier::inscan);
+ // TODOSCAN: Take care of loop and add asserts if required
+ if (isInScanRegion) {
+ llvm::Expected<SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
+ ompBuilder->createCanonicalScanLoops(
+ loc, bodyGen, lowerBound, upperBound, step,
+ /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, "loop",
+ isInScanRegion);
+
+ if (failed(handleError(loopResults, *loopOp)))
+ return failure();
+ auto beforeLoop = loopResults->front();
+ auto afterLoop = loopResults->back();
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](OpenMPLoopInfoStackFrame &frame) {
+ frame.loopInfos.push_back(beforeLoop);
+ frame.loopInfos.push_back(afterLoop);
+ return WalkResult::interrupt();
+ });
+ builder.restoreIP(afterLoop->getAfterIP());
+ return success();
+ } else {
+ llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
+ ompBuilder->createCanonicalLoop(
+ loc, bodyGen, lowerBound, upperBound, step,
+ /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
- llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
- ompBuilder->createCanonicalLoop(
- loc, bodyGen, lowerBound, upperBound, step,
- /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
+ if (failed(handleError(loopResult, *loopOp)))
+ return failure();
- if (failed(handleError(loopResult, *loopOp)))
- return failure();
+ loopInfos.push_back(*loopResult);
+ }
+ } else {
+ llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
+ ompBuilder->createCanonicalLoop(
+ loc, bodyGen, lowerBound, upperBound, step,
+ /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
+
+ if (failed(handleError(loopResult, *loopOp)))
+ return failure();
- loopInfos.push_back(*loopResult);
+ loopInfos.push_back(*loopResult);
+ }
}
// Collapse loops. Store the insertion point because LoopInfos may get
@@ -2633,7 +2763,8 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
// after applying transformations.
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
+ frame.loopInfos.push_back(
+ ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}));
return WalkResult::interrupt();
});
@@ -4212,18 +4343,20 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
bool loopNeedsBarrier = false;
llvm::Value *chunk = nullptr;
- llvm::CanonicalLoopInfo *loopInfo =
- findCurrentLoopInfo(moduleTranslation);
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
- ompBuilder->applyWorkshareLoop(
- ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
- convertToScheduleKind(schedule), chunk, isSimd,
- scheduleMod == omp::ScheduleModifier::monotonic,
- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType);
-
- if (!wsloopIP)
- return wsloopIP.takeError();
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+ findCurrentLoopInfos(moduleTranslation);
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+ ompBuilder->applyWorkshareLoop(
+ ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+ convertToScheduleKind(schedule), chunk, isSimd,
+ scheduleMod == omp::ScheduleModifier::monotonic,
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+ workshareLoopType);
+
+ if (!wsloopIP)
+ return wsloopIP.takeError();
+ }
}
if (failed(cleanupPrivateVars(builder, moduleTranslation,
@@ -5202,6 +5335,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
.Case([&](omp::WsloopOp) {
return convertOmpWsloop(*op, builder, moduleTranslation);
})
+ .Case([&](omp::ScanOp) {
+ return convertOmpScan(*op, builder, moduleTranslation);
+ })
.Case([&](omp::SimdOp) {
return convertOmpSimd(*op, builder, moduleTranslation);
})
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-scan-out b/mlir/test/Target/LLVMIR/openmp-reduction-scan-out
new file mode 100644
index 0000000000000..33ee30ca1fc10
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-scan-out
@@ -0,0 +1,364 @@
+; ModuleID = 'LLVMDialectModule'
+source_filename = "LLVMDialectModule"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+%struct.ident_t = type { i32, i32, i32, i32, ptr }
+
+ at _QFEa = internal global [100 x i32] zeroinitializer
+ at _QFEb = internal global [100 x i32] zeroinitializer
+ at _QFECn = internal constant i32 100
+ at 0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
+ at 1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
+ at 2 = private unnamed_addr constant %struct.ident_t { i32 0, i32 34, i32 0, i32 22, ptr @0 }, align 8
+ at 3 = private unnamed_addr constant %struct.ident_t { i32 0, i32 66, i32 0, i32 22, ptr @0 }, align 8
+
+define void @_QQmain() {
+ %structArg = alloca { ptr }, align 8
+ %1 = alloca i32, i64 1, align 4
+ %2 = alloca i32, i64 1, align 4
+ %3 = alloca i32, i64 1, align 4
+ %4 = alloca i32, i64 1, align 4
+ store i32 0, ptr %3, align 4
+ br label %5
+
+5: ; preds = %9, %0
+ %6 = phi i32 [ %18, %9 ], [ 1, %0 ]
+ %7 = phi i64 [ %19, %9 ], [ 100, %0 ]
+ %8 = icmp sgt i64 %7, 0
+ br i1 %8, label %9, label %20
+
+9: ; preds = %5
+ store i32 %6, ptr %4, align 4
+ %10 = load i32, ptr %4, align 4
+ %11 = sext i32 %10 to i64
+ %12 = sub nsw i64 %11, 1
+ %13 = mul nsw i64 %12, 1
+ %14 = mul nsw i64 %13, 1
+ %15 = add nsw i64 %14, 0
+ %16 = getelementptr i32, ptr @_QFEa, i64 %15
+ store i32 %10, ptr %16, align 4
+ %17 = load i32, ptr %4, align 4
+ %18 = add i32 %17, 1
+ %19 = sub i64 %7, 1
+ br label %5
+
+20: ; preds = %5
+ store i32 %6, ptr %4, align 4
+ %omp_global_thread_num = call i32 @__kmpc_global_thread_num(ptr @1)
+ br label %omp_parallel
+
+omp_parallel: ; preds = %20
+ %gep_ = getelementptr { ptr }, ptr %structArg, i32 0, i32 0
+ store ptr %3, ptr %gep_, align 8
+ call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @_QQmain..omp_par, ptr %structArg)
+ br label %omp.par.exit
+
+omp.par.exit: ; preds = %omp_parallel
+ ret void
+}
+
+; Function Attrs: nounwind
+define internal void @_QQmain..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+omp.par.entry:
+ %gep_ = getelementptr { ptr }, ptr %0, i32 0, i32 0
+ %loadgep_ = load ptr, ptr %gep_, align 8
+ %p.lastiter28 = alloca i32, align 4
+ %p.lowerbound29 = alloca i32, align 4
+ %p.upperbound30 = alloca i32, align 4
+ %p.stride31 = alloca i32, align 4
+ %p.lastiter = alloca i32, align 4
+ %p.lowerbound = alloca i32, align 4
+ %p.upperbound = alloca i32, align 4
+ %p.stride = alloca i32, align 4
+ %tid.addr.local = alloca i32, align 4
+ %1 = load i32, ptr %tid.addr, align 4
+ store i32 %1, ptr %tid.addr.local, align 4
+ %tid = load i32, ptr %tid.addr.local, align 4
+ %2 = alloca i32, align 4
+ br label %omp.region.after_alloca2
+
+omp.region.after_alloca2: ; preds = %omp.par.entry
+ %vla = alloca i32, i32 100, align 4
+ br label %omp.region.after_alloca
+
+omp.region.after_alloca: ; preds = %omp.region.after_alloca2
+ br label %omp.par.region
+
+omp.par.region: ; preds = %omp.region.after_alloca
+ br label %omp.par.region1
+
+omp.par.region1: ; preds = %omp.par.region
+ %3 = alloca i32, i64 1, align 4
+ br label %omp.reduction.init
+
+omp.reduction.init: ; preds = %omp.par.region1
+ store i32 0, ptr %2, align 4
+ br label %omp.wsloop.region
+
+omp.wsloop.region: ; preds = %omp.reduction.init
+ br label %omp_loop.preheader
+
+omp_loop.preheader: ; preds = %omp.wsloop.region
+ store i32 0, ptr %p.lowerbound, align 4
+ store i32 99, ptr %p.upperbound, align 4
+ store i32 1, ptr %p.stride, align 4
+ %omp_global_thread_num26 = call i32 @__kmpc_global_thread_num(ptr @1)
+ call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num26, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+ %4 = load i32, ptr %p.lowerbound, align 4
+ %5 = load i32, ptr %p.upperbound, align 4
+ %6 = sub i32 %5, %4
+ %7 = add i32 %6, 1
+ br label %omp_loop.header
+
+omp_loop.header: ; preds = %omp_loop.inc, %omp_loop.preheader
+ %omp_loop.iv = phi i32 [ 0, %omp_loop.preheader ], [ %omp_loop.next, %omp_loop.inc ]
+ br label %omp_loop.cond
+
+omp_loop.cond: ; preds = %omp_loop.header
+ %omp_loop.cmp = icmp ult i32 %omp_loop.iv, %7
+ br i1 %omp_loop.cmp, label %omp_loop.body, label %omp_loop.exit
+
+omp_loop.exit: ; preds = %omp_loop.cond
+ call void @__kmpc_for_static_fini(ptr @1, i32 %omp_global_thread_num26)
+ %omp_global_thread_num27 = call i32 @__kmpc_global_thread_num(ptr @1)
+ call void @__kmpc_barrier(ptr @3, i32 %omp_global_thread_num27)
+ br label %omp_loop.after
+
+omp_loop.after: ; preds = %omp_loop.exit
+ %8 = call double @llvm.log2.f64(double 1.000000e+02) #0
+ %9 = call double @llvm.ceil.f64(double %8) #0
+ %10 = fptoui double %9 to i32
+ br label %omp.outer.log.scan.body
+
+omp.outer.log.scan.body: ; preds = %omp.inner.log.scan.exit, %omp_loop.after
+ %11 = phi i32 [ 0, %omp_loop.after ], [ %14, %omp.inner.log.scan.exit ]
+ %12 = phi i32 [ 1, %omp_loop.after ], [ %15, %omp.inner.log.scan.exit ]
+ %13 = icmp uge i32 99, %12
+ br i1 %13, label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit
+
+omp.inner.log.scan.exit: ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body
+ %14 = add nuw i32 %11, 1
+ %15 = shl nuw i32 %12, 1
+ %16 = icmp ne i32 %14, %10
+ br i1 %16, label %omp.outer.log.scan.body, label %omp.outer.log.scan.exit
+
+omp.outer.log.scan.exit: ; preds = %omp.inner.log.scan.exit
+ %omp_global_thread_num24 = call i32 @__kmpc_global_thread_num(ptr @1)
+ call void @__kmpc_barrier(ptr @2, i32 %omp_global_thread_num24)
+ br label %omp.scan.loop.cont
+
+omp.scan.loop.cont: ; preds = %omp.outer.log.scan.exit
+ br label %omp_loop.preheader5
+
+omp_loop.preheader5: ; preds = %omp.scan.loop.cont
+ store i32 0, ptr %p.lowerbound29, align 4
+ store i32 99, ptr %p.upperbound30, align 4
+ store i32 1, ptr %p.stride31, align 4
+ %omp_global_thread_num32 = call i32 @__kmpc_global_thread_num(ptr @1)
+ call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num32, i32 34, ptr %p.lastiter28, ptr %p.lowerbound29, ptr %p.upperbound30, ptr %p.stride31, i32 1, i32 0)
+ %17 = load i32, ptr %p.lowerbound29, align 4
+ %18 = load i32, ptr %p.upperbound30, align 4
+ %19 = sub i32 %18, %17
+ %20 = add i32 %19, 1
+ br label %omp_loop.header6
+
+omp_loop.header6: ; preds = %omp_loop.inc9, %omp_loop.preheader5
+ %omp_loop.iv12 = phi i32 [ 0, %omp_loop.preheader5 ], [ %omp_loop.next14, %omp_loop.inc9 ]
+ br label %omp_loop.cond7
+
+omp_loop.cond7: ; preds = %omp_loop.header6
+ %omp_loop.cmp13 = icmp ult i32 %omp_loop.iv12, %20
+ br i1 %omp_loop.cmp13, label %omp_loop.body8, label %omp_loop.exit10
+
+omp_loop.exit10: ; preds = %omp_loop.cond7
+ call void @__kmpc_for_static_fini(ptr @1, i32 %omp_global_thread_num32)
+ %omp_global_thread_num33 = call i32 @__kmpc_global_thread_num(ptr @1)
+ call void @__kmpc_barrier(ptr @3, i32 %omp_global_thread_num33)
+ br label %omp_loop.after11
+
+omp_loop.after11: ; preds = %omp_loop.exit10
+ %arrayOffset25 = getelementptr inbounds i32, ptr %vla, i32 100
+ %21 = load i32, ptr %arrayOffset25, align 4
+ store i32 %21, ptr %loadgep_, align 4
+ br label %omp.region.cont3
+
+omp.region.cont3: ; preds = %omp_loop.after11
+ br label %omp.region.cont
+
+omp.region.cont: ; preds = %omp.region.cont3
+ br label %omp.par.pre_finalize
+
+omp.par.pre_finalize: ; preds = %omp.region.cont
+ br label %omp.par.exit.exitStub
+
+omp_loop.body8: ; preds = %omp_loop.cond7
+ %22 = add i32 %omp_loop.iv12, %17
+ %23 = mul i32 %22, 1
+ %24 = add i32 %23, 1
+ br label %omp.inscan.dispatch17
+
+omp.inscan.dispatch17: ; preds = %omp_loop.body8
+ store i32 0, ptr %2, align 4
+ %arrayOffset20 = getelementptr inbounds i32, ptr %vla, i32 %24
+ %25 = load i32, ptr %arrayOffset20, align 4
+ store i32 %25, ptr %2, align 4
+ br i1 true, label %omp.after.scan.bb21, label %omp.before.scan.bb15
+
+omp.before.scan.bb15: ; preds = %omp.inscan.dispatch17
+ br label %omp.loop_nest.region19
+
+omp.loop_nest.region19: ; preds = %omp.before.scan.bb15
+ store i32 %24, ptr %3, align 4
+ %26 = load i32, ptr %2, align 4
+ %27 = load i32, ptr %3, align 4
+ %28 = sext i32 %27 to i64
+ %29 = sub nsw i64 %28, 1
+ %30 = mul nsw i64 %29, 1
+ %31 = mul nsw i64 %30, 1
+ %32 = add nsw i64 %31, 0
+ %33 = getelementptr i32, ptr @_QFEa, i64 %32
+ %34 = load i32, ptr %33, align 4
+ %35 = add i32 %26, %34
+ store i32 %35, ptr %2, align 4
+ br label %omp.scan.loop.exit16
+
+omp.scan.loop.exit16: ; preds = %omp.loop_nest.region19, %omp.region.cont18
+ br label %omp_loop.inc9
+
+omp_loop.inc9: ; preds = %omp.scan.loop.exit16
+ %omp_loop.next14 = add nuw i32 %omp_loop.iv12, 1
+ br label %omp_loop.header6
+
+omp.after.scan.bb21: ; preds = %omp.inscan.dispatch17
+ store i32 %24, ptr %3, align 4
+ %36 = load i32, ptr %2, align 4
+ %37 = load i32, ptr %3, align 4
+ %38 = sext i32 %37 to i64
+ %39 = sub nsw i64 %38, 1
+ %40 = mul nsw i64 %39, 1
+ %41 = mul nsw i64 %40, 1
+ %42 = add nsw i64 %41, 0
+ %43 = getelementptr i32, ptr @_QFEb, i64 %42
+ store i32 %36, ptr %43, align 4
+ br label %omp.region.cont18
+
+omp.region.cont18: ; preds = %omp.after.scan.bb21
+ br label %omp.scan.loop.exit16
+
+omp.inner.log.scan.body: ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body
+ %44 = phi i32 [ 99, %omp.outer.log.scan.body ], [ %50, %omp.inner.log.scan.body ]
+ %45 = add i32 %44, 1
+ %arrayOffset22 = getelementptr inbounds i32, ptr %vla, i32 %45
+ %46 = sub nuw i32 %45, %12
+ %arrayOffset23 = getelementptr inbounds i32, ptr %vla, i32 %46
+ %47 = load i32, ptr %arrayOffset22, align 4
+ %48 = load i32, ptr %arrayOffset23, align 4
+ %49 = add i32 %47, %48
+ store i32 %49, ptr %arrayOffset22, align 4
+ %50 = sub nuw i32 %44, 1
+ %51 = icmp uge i32 %50, %12
+ br i1 %51, label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit
+
+omp_loop.body: ; preds = %omp_loop.cond
+ %52 = add i32 %omp_loop.iv, %4
+ %53 = mul i32 %52, 1
+ %54 = add i32 %53, 1
+ br label %omp.inscan.dispatch
+
+omp.inscan.dispatch: ; preds = %omp_loop.body
+ store i32 0, ptr %2, align 4
+ br i1 true, label %omp.before.scan.bb, label %omp.after.scan.bb
+
+omp.after.scan.bb: ; preds = %omp.inscan.dispatch
+ store i32 %54, ptr %3, align 4
+ %55 = load i32, ptr %2, align 4
+ %56 = load i32, ptr %3, align 4
+ %57 = sext i32 %56 to i64
+ %58 = sub nsw i64 %57, 1
+ %59 = mul nsw i64 %58, 1
+ %60 = mul nsw i64 %59, 1
+ %61 = add nsw i64 %60, 0
+ %62 = getelementptr i32, ptr @_QFEb, i64 %61
+ store i32 %55, ptr %62, align 4
+ br label %omp.region.cont4
+
+omp.region.cont4: ; preds = %omp.after.scan.bb
+ br label %omp.scan.loop.exit
+
+omp.scan.loop.exit: ; preds = %omp.loop_nest.region, %omp.region.cont4
+ br label %omp_loop.inc
+
+omp_loop.inc: ; preds = %omp.scan.loop.exit
+ %omp_loop.next = add nuw i32 %omp_loop.iv, 1
+ br label %omp_loop.header
+
+omp.before.scan.bb: ; preds = %omp.inscan.dispatch
+ br label %omp.loop_nest.region
+
+omp.loop_nest.region: ; preds = %omp.before.scan.bb
+ store i32 %54, ptr %3, align 4
+ %63 = load i32, ptr %2, align 4
+ %64 = load i32, ptr %3, align 4
+ %65 = sext i32 %64 to i64
+ %66 = sub nsw i64 %65, 1
+ %67 = mul nsw i64 %66, 1
+ %68 = mul nsw i64 %67, 1
+ %69 = add nsw i64 %68, 0
+ %70 = getelementptr i32, ptr @_QFEa, i64 %69
+ %71 = load i32, ptr %70, align 4
+ %72 = add i32 %63, %71
+ store i32 %72, ptr %2, align 4
+ %arrayOffset = getelementptr inbounds i32, ptr %vla, i32 %54
+ %73 = load i32, ptr %2, align 4
+ store i32 %73, ptr %arrayOffset, align 4
+ br label %omp.scan.loop.exit
+
+omp.par.exit.exitStub: ; preds = %omp.par.pre_finalize
+ ret void
+}
+
+declare void @_FortranAProgramStart(i32, ptr, ptr, ptr)
+
+declare void @_FortranAProgramEndStatement()
+
+define i32 @main(i32 %0, ptr %1, ptr %2) {
+ call void @_FortranAProgramStart(i32 %0, ptr %1, ptr %2, ptr null)
+ call void @_QQmain()
+ call void @_FortranAProgramEndStatement()
+ ret i32 0
+}
+
+; Function Attrs: nounwind
+declare i32 @__kmpc_global_thread_num(ptr) #0
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare double @llvm.log2.f64(double) #1
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare double @llvm.ceil.f64(double) #1
+
+; Function Attrs: convergent nounwind
+declare void @__kmpc_barrier(ptr, i32) #2
+
+; Function Attrs: nounwind
+declare void @__kmpc_for_static_init_4u(ptr, i32, i32, ptr, ptr, ptr, ptr, i32, i32) #0
+
+; Function Attrs: nounwind
+declare void @__kmpc_for_static_fini(ptr, i32) #0
+
+; Function Attrs: nounwind
+declare !callback !3 void @__kmpc_fork_call(ptr, i32, ptr, ...) #0
+
+attributes #0 = { nounwind }
+attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+attributes #2 = { convergent nounwind }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 7, !"openmp", i32 11}
+!1 = !{i32 2, !"Debug Info Version", i32 3}
+!2 = !{!"flang version 20.0.0 (git at github.com:anchuraj/llvm-project.git aedc369685b22e2b8f7413557d292a78637f563b)"}
+!3 = !{!4}
+!4 = !{i64 2, i64 -1, i64 -1, i1 true}
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
new file mode 100644
index 0000000000000..b0e2b9cb7c97c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+ %0 = llvm.add %arg0, %arg1 : i32
+ omp.yield(%0 : i32)
+}
+// An inscan reduction on a wsloop is expected to lower into two worksharing
+// loops with a logarithmic-step combine over a temporary buffer between them.
+// CHECK-LABEL: define void @scan_reduction()
+// CHECK: call void {{.*}}@__kmpc_fork_call({{.*}}@scan_reduction..omp_par
+// CHECK-LABEL: define internal void @scan_reduction..omp_par(
+// CHECK-DAG: call void @__kmpc_for_static_init_4u(
+// CHECK-DAG: call void @__kmpc_for_static_init_4u(
+// CHECK-DAG: call double @llvm.log2.f64(
+// CHECK-DAG: call double @llvm.ceil.f64(
+// CHECK-DAG: omp.inscan.dispatch{{[0-9]*}}:
+// CHECK-DAG: omp.outer.log.scan.body:
+// CHECK-DAG: omp.inner.log.scan.body:
+llvm.func @scan_reduction() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x i32 {bindc_name = "z"} : (i64) -> !llvm.ptr
+ %2 = llvm.mlir.constant(1 : i64) : i64
+ %3 = llvm.alloca %2 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr
+ %4 = llvm.mlir.constant(1 : i64) : i64
+ %5 = llvm.alloca %4 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+ %6 = llvm.mlir.constant(1 : i64) : i64
+ %7 = llvm.alloca %6 x i32 {bindc_name = "k"} : (i64) -> !llvm.ptr
+ %8 = llvm.mlir.constant(0 : index) : i64
+ %9 = llvm.mlir.constant(1 : index) : i64
+ %10 = llvm.mlir.constant(100 : i32) : i32
+ %11 = llvm.mlir.constant(1 : i32) : i32
+ %12 = llvm.mlir.constant(0 : i32) : i32
+ %13 = llvm.mlir.constant(100 : index) : i64
+ %14 = llvm.mlir.addressof @_QFEa : !llvm.ptr
+ %15 = llvm.mlir.addressof @_QFEb : !llvm.ptr
+ omp.parallel {
+ %37 = llvm.mlir.constant(1 : i64) : i64
+ %38 = llvm.alloca %37 x i32 {bindc_name = "k", pinned} : (i64) -> !llvm.ptr
+ %39 = llvm.mlir.constant(1 : i64) : i64
+ omp.wsloop reduction(mod: inscan, @add_reduction_i32 %5 -> %arg0 : !llvm.ptr) {
+ omp.loop_nest (%arg1) : i32 = (%11) to (%10) inclusive step (%11) {
+ llvm.store %arg1, %38 : i32, !llvm.ptr
+ %40 = llvm.load %arg0 : !llvm.ptr -> i32
+ %41 = llvm.load %38 : !llvm.ptr -> i32
+ %42 = llvm.sext %41 : i32 to i64
+ %50 = llvm.getelementptr %14[%42] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+ %51 = llvm.load %50 : !llvm.ptr -> i32
+ %52 = llvm.add %40, %51 : i32
+ llvm.store %52, %arg0 : i32, !llvm.ptr
+ omp.scan inclusive(%arg0 : !llvm.ptr)
+ %53 = llvm.load %arg0 : !llvm.ptr -> i32
+ %54 = llvm.load %38 : !llvm.ptr -> i32
+ %55 = llvm.sext %54 : i32 to i64
+ %63 = llvm.getelementptr %15[%55] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+ llvm.store %53, %63 : i32, !llvm.ptr
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ llvm.return
+}
+llvm.mlir.global internal @_QFEa() {addr_space = 0 : i32} : !llvm.array<100 x i32> {
+ %0 = llvm.mlir.zero : !llvm.array<100 x i32>
+ llvm.return %0 : !llvm.array<100 x i32>
+}
+llvm.mlir.global internal @_QFEb() {addr_space = 0 : i32} : !llvm.array<100 x i32> {
+ %0 = llvm.mlir.zero : !llvm.array<100 x i32>
+ llvm.return %0 : !llvm.array<100 x i32>
+}
+llvm.mlir.global internal constant @_QFECn() {addr_space = 0 : i32} : i32 {
+ %0 = llvm.mlir.constant(100 : i32) : i32
+ llvm.return %0 : i32
+}
More information about the Mlir-commits
mailing list