[Mlir-commits] [flang] [mlir] [MLIR][LLVMIR] Adding scan lowering to llvm from mlir (PR #165788)
llvmlistbot at llvm.org
Fri Oct 31 11:44:21 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-flang-openmp
Author: Anchu Rajendran S (anchuraj)
<details>
<summary>Changes</summary>
OpenMP supports scan reductions with the help of the `scan` directive. The reduction clause of a worksharing-loop or `simd` directive takes the `inscan` modifier when a scan reduction is requested, and with an `inscan` modifier the body of the construct must contain a `scan` directive. This PR implements the lowering logic for scan reductions in OpenMP worksharing loops. The corresponding OpenMPIRBuilder support can be found in https://github.com/llvm/llvm-project/pull/136035. Support for nested loops and the `exclusive` clause is not included in this PR.
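
For context, here is a minimal Fortran sketch (not part of the patch) of the construct being lowered: an `inscan` reduction on a worksharing loop whose body uses a `scan` directive to separate the input phase from the scan phase. Array names and bounds are illustrative.

```fortran
! Hypothetical usage sketch: an inclusive scan reduction in a single,
! non-nested worksharing loop -- the case this PR lowers.
program scan_example
  implicit none
  integer :: i, total
  integer :: a(100), prefix(100)

  a = 1
  total = 0
  !$omp parallel do reduction(inscan, +:total)
  do i = 1, 100
    total = total + a(i)          ! input phase: runs before the scan directive
    !$omp scan inclusive(total)
    prefix(i) = total             ! scan phase: prefix(i) is the inclusive prefix sum
  end do
  !$omp end parallel do

  print *, prefix(100)            ! prints 100
end program scan_example
```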
---
Patch is 37.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165788.diff
4 Files Affected:
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+34-5)
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+303-77)
- (added) mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir (+130)
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (+36-2)
``````````diff
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f86ee01355104..5d82466889b1e 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2326,12 +2326,41 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
static mlir::omp::ScanOp
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx, mlir::Location loc,
- const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+ semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+ mlir::Location loc, const ConstructQueue &queue,
+ ConstructQueue::const_iterator item) {
mlir::omp::ScanOperands clauseOps;
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
- return mlir::omp::ScanOp::create(converter.getFirOpBuilder(),
- converter.getCurrentLocation(), clauseOps);
+ mlir::omp::ScanOp scanOp = mlir::omp::ScanOp::create(
+ converter.getFirOpBuilder(), converter.getCurrentLocation(), clauseOps);
+ // If there are nested loops all indices should be loaded after
+ // the scan construct as otherwise, it would result in using the index
+ // variable across scan directive.
+ // (`Intra-iteration dependences from a statement in the structured
+ // block sequence that precede a scan directive to a statement in the
+ // structured block sequence that follows a scan directive must not exist,
+ // except for dependences for the list items specified in an inclusive or
+ // exclusive clause.`).
+ // TODO: If there are nested loops, it is not handled.
+ mlir::omp::LoopNestOp loopNestOp =
+ scanOp->getParentOfType<mlir::omp::LoopNestOp>();
+ assert(loopNestOp.getNumLoops() == 1 &&
+ "Scan directive inside nested do loops is not handled yet.");
+  mlir::Region &region = loopNestOp->getRegion(0);
+ mlir::Value indexVal = fir::getBase(region.getArgument(0));
+ lower::pft::Evaluation *doConstructEval = eval.parentConstruct;
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ lower::pft::Evaluation *doLoop = &doConstructEval->getFirstNestedEvaluation();
+ auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
+ assert(doStmt && "Expected do loop to be in the nested evaluation");
+ const auto &loopControl =
+ std::get<std::optional<parser::LoopControl>>(doStmt->t);
+ const parser::LoopControl::Bounds *bounds =
+ std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+ mlir::Operation *storeOp =
+ setLoopVar(converter, loc, indexVal, bounds->name.thing.symbol);
+ firOpBuilder.setInsertionPointAfter(storeOp);
+ return scanOp;
}
static mlir::omp::SectionsOp
@@ -3416,7 +3445,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
loc, queue, item);
break;
case llvm::omp::Directive::OMPD_scan:
- newOp = genScanOp(converter, symTable, semaCtx, loc, queue, item);
+ newOp = genScanOp(converter, symTable, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_section:
llvm_unreachable("genOMPDispatch: OMPD_section");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1e2099d6cc1b2..9db269c4f8756 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -37,6 +37,7 @@
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include <cassert>
#include <cstdint>
#include <iterator>
#include <numeric>
@@ -77,6 +78,22 @@ class OpenMPAllocaStackFrame
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
+/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
+/// insertion points for allocas of parent of the current parallel region. The
+/// insertion point is used to allocate variables to be share by the threads
+/// executing the parallel region. Lowering of scan reduction requires declaring
+/// shared pointers to the temporary buffer to perform scan reduction.
+class OpenMPParallelAllocaStackFrame
+ : public StateStackFrameBase<OpenMPParallelAllocaStackFrame> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPParallelAllocaStackFrame)
+
+ explicit OpenMPParallelAllocaStackFrame(
+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
+ : allocaInsertPoint(allocaIP) {}
+ llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+};
+
/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
/// operation.
@@ -84,7 +101,13 @@ class OpenMPLoopInfoStackFrame
: public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
- llvm::CanonicalLoopInfo *loopInfo = nullptr;
+ /// For constructs like scan, one LoopInfo frame can contain multiple
+ /// Canonical Loops as a single openmpLoopNestOp will be split into input
+ /// loop and scan loop.
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+ llvm::ScanInfo *scanInfo;
+ llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+ new llvm::DenseMap<llvm::Value *, llvm::Type *>();
};
/// Custom error class to signal translation errors that don't need reporting,
@@ -323,6 +346,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.getDistScheduleChunkSize())
result = todo("dist_schedule with chunk_size");
};
+ auto checkExclusive = [&todo](auto op, LogicalResult &result) {
+ if (!op.getExclusiveVars().empty())
+ result = todo("exclusive");
+ };
auto checkHint = [](auto op, LogicalResult &) {
if (op.getHint())
op.emitWarning("hint clause discarded");
@@ -371,9 +398,14 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (!op.getReductionVars().empty() || op.getReductionByref() ||
op.getReductionSyms())
result = todo("reduction");
- if (op.getReductionMod() &&
- op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
- result = todo("reduction with modifier");
+ if (op.getReductionMod()) {
+ if (isa<omp::WsloopOp>(op)) {
+ if (op.getReductionMod().value() == omp::ReductionModifier::task)
+ result = todo("reduction with task modifier");
+ } else {
+ result = todo("reduction with modifier");
+ }
+ }
};
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -397,6 +429,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkOrder(op, result);
})
.Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
+ .Case([&](omp::ScanOp op) { checkExclusive(op, result); })
.Case([&](omp::SectionsOp op) {
checkAllocate(op, result);
checkPrivate(op, result);
@@ -531,15 +564,59 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
/// Find the loop information structure for the loop nest being translated. It
/// will return a `null` value unless called from the translation function for
/// a loop wrapper operation after successfully translating its body.
-static llvm::CanonicalLoopInfo *
-findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
- llvm::CanonicalLoopInfo *loopInfo = nullptr;
+static SmallVector<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](OpenMPLoopInfoStackFrame &frame) {
+ loopInfos = frame.loopInfos;
+ return WalkResult::interrupt();
+ });
+ return loopInfos;
+}
+
+// LoopFrame stores the scaninfo which is used for scan reduction.
+// Upon encountering an `inscan` reduction modifier, `scanInfoInitialize`
+// initializes the ScanInfo and is used when scan directive is encountered
+// in the body of the loop nest.
+static llvm::ScanInfo *
+findScanInfo(LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::ScanInfo *scanInfo;
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](OpenMPLoopInfoStackFrame &frame) {
+ scanInfo = frame.scanInfo;
+ return WalkResult::interrupt();
+ });
+ return scanInfo;
+}
+
+// The types of reduction vars are used for lowering scan directive which
+// appears in the body of the loop. The types are stored in loop frame when
+// reduction clause is encountered and is used when scan directive is
+// encountered.
+static llvm::DenseMap<llvm::Value *, llvm::Type *> *
+findReductionVarTypes(LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType = nullptr;
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- loopInfo = frame.loopInfo;
+ reductionVarToType = frame.reductionVarToType;
return WalkResult::interrupt();
});
- return loopInfo;
+ return reductionVarToType;
+}
+
+// Scan reduction requires a shared buffer to be allocated to perform reduction.
+// ParallelAllocaStackFrame holds the allocaIP where shared allocation can be
+// done.
+static llvm::OpenMPIRBuilder::InsertPointTy
+findParallelAllocaIP(LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder::InsertPointTy parallelAllocaIP;
+ moduleTranslation.stackWalk<OpenMPParallelAllocaStackFrame>(
+ [&](OpenMPParallelAllocaStackFrame &frame) {
+ parallelAllocaIP = frame.allocaInsertPoint;
+ return WalkResult::interrupt();
+ });
+ return parallelAllocaIP;
}
/// Converts the given region that appears within an OpenMP dialect operation to
@@ -1254,11 +1331,17 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
for (auto [data, addr] : deferredStores)
builder.CreateStore(data, addr);
+ llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+ findReductionVarTypes(moduleTranslation);
// Before the loop, store the initial values of reductions into reduction
// variables. Although this could be done after allocas, we don't want to mess
// up with the alloca insertion point.
for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
SmallVector<llvm::Value *, 1> phis;
+ llvm::Type *reductionType =
+ moduleTranslation.convertType(reductionDecls[i].getType());
+ if (reductionVarToType != nullptr)
+ (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
// map block argument to initializer region
mapInitializationArgs(op, moduleTranslation, reductionDecls,
@@ -1330,15 +1413,20 @@ static void collectReductionInfo(
// Collect the reduction information.
reductionInfos.reserve(numReductions);
+ llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+ findReductionVarTypes(moduleTranslation);
for (unsigned i = 0; i < numReductions; ++i) {
llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
if (owningAtomicReductionGens[i])
atomicGen = owningAtomicReductionGens[i];
llvm::Value *variable =
moduleTranslation.lookupValue(loop.getReductionVars()[i]);
+ llvm::Type *reductionType =
+ moduleTranslation.convertType(reductionDecls[i].getType());
+ if (reductionVarToType != nullptr)
+ (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
reductionInfos.push_back(
- {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
- privateReductionVariables[i],
+ {reductionType, variable, privateReductionVariables[i],
/*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
owningReductionGens[i],
/*ReductionGenClang=*/nullptr, atomicGen});
@@ -2543,6 +2631,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
bool isSimd = wsloopOp.getScheduleSimd();
bool loopNeedsBarrier = !wsloopOp.getNowait();
+ bool isInScanRegion =
+ wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+ mlir::omp::ReductionModifier::inscan);
// The only legal way for the direct parent to be omp.distribute is that this
// represents 'distribute parallel do'. Otherwise, this is a regular
@@ -2574,20 +2665,81 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(regionBlock, opInst)))
return failure();
- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+ findCurrentLoopInfos(moduleTranslation);
+
+ const auto &&wsloopCodeGen = [&](llvm::CanonicalLoopInfo *loopInfo,
+ bool noLoopMode, bool inputScanLoop) {
+ bool emitLinearVarInit = !isInScanRegion || inputScanLoop;
+ // Emit Initialization and Update IR for linear variables
+ if (emitLinearVarInit && !wsloopOp.getLinearVars().empty()) {
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+ linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+ loopInfo->getPreheader());
+ if (failed(handleError(afterBarrierIP, *loopOp)))
+ return failure();
+ builder.restoreIP(*afterBarrierIP);
+ linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+ loopInfo->getIndVar());
+ }
+ bool emitLinearVarFinalize = !isInScanRegion || !inputScanLoop;
+ if (emitLinearVarFinalize)
+ linearClauseProcessor.outlineLinearFinalizationBB(builder,
+ loopInfo->getExit());
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+ ompBuilder->applyWorkshareLoop(
+ ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+ convertToScheduleKind(schedule), chunk, isSimd,
+ scheduleMod == omp::ScheduleModifier::monotonic,
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+ workshareLoopType, noLoopMode);
+
+ if (failed(handleError(wsloopIP, opInst)))
+ return failure();
- // Emit Initialization and Update IR for linear variables
- if (!wsloopOp.getLinearVars().empty()) {
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
- linearClauseProcessor.initLinearVar(builder, moduleTranslation,
- loopInfo->getPreheader());
- if (failed(handleError(afterBarrierIP, *loopOp)))
+ // Emit finalization and in-place rewrites for linear vars.
+ if (emitLinearVarFinalize && !wsloopOp.getLinearVars().empty()) {
+ llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+ if (loopInfo->getLastIter())
+ return failure();
+ // assert(loopInfo->getLastIter() &&
+ // "`lastiter` in CanonicalLoopInfo is nullptr");
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+ linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
+ loopInfo->getLastIter());
+ if (failed(handleError(afterBarrierIP, *loopOp)))
+ return failure();
+ for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
+ linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
+ index);
+ builder.restoreIP(oldIP);
+ }
+ if (!inputScanLoop || !isInScanRegion)
+ popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
+
+ return llvm::success();
+ };
+
+ if (isInScanRegion) {
+ auto inputLoopFinishIp = loopInfos.front()->getAfterIP();
+ builder.restoreIP(inputLoopFinishIp);
+ SmallVector<OwningReductionGen> owningReductionGens;
+ SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+ collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
+ owningReductionGens, owningAtomicReductionGens,
+ privateReductionVariables, reductionInfos);
+ llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
+ llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
+ ompBuilder->emitScanReduction(builder.saveIP(), reductionInfos,
+ scanInfo);
+ if (failed(handleError(redIP, opInst)))
return failure();
- builder.restoreIP(*afterBarrierIP);
- linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
- loopInfo->getIndVar());
- linearClauseProcessor.outlineLinearFinalizationBB(builder,
- loopInfo->getExit());
+
+ builder.restoreIP(*redIP);
+ builder.CreateBr(cont);
}
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
@@ -2612,42 +2764,34 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
}
}
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
- ompBuilder->applyWorkshareLoop(
- ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
- convertToScheduleKind(schedule), chunk, isSimd,
- scheduleMod == omp::ScheduleModifier::monotonic,
- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType, noLoopMode);
-
- if (failed(handleError(wsloopIP, opInst)))
- return failure();
-
- // Emit finalization and in-place rewrites for linear vars.
- if (!wsloopOp.getLinearVars().empty()) {
- llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
- assert(loopInfo->getLastIter() &&
- "`lastiter` in CanonicalLoopInfo is nullptr");
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
- linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
- loopInfo->getLastIter());
- if (failed(handleError(afterBarrierIP, *loopOp)))
+ // For Scan loops input loop need not pop cancellation CB and hence, it is set
+ // false for the first loop
+ bool inputScanLoop = isInScanRegion;
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+ if (failed(wsloopCodeGen(loopInfo, noLoopMode, inputScanLoop)))
return failure();
- for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
- linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
- index);
- builder.restoreIP(oldIP);
+ inputScanLoop = false;
}
- // Set the correct branch target for task cancellation
- popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
-
- // Process the reductions if required.
- if (failed(createReductionsAndCleanup(
- wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
- privateReductionVariables, isByRef, wsloopOp.getNowait(),
- /*isTeamsReduction=*/false)))
- return failure();
+ // todo: change builder.saveIP to wsLoopIP
+ if (isInScanRegion) {
+ SmallVector<Region *> reductionRegions;
+ llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
+ [](omp::DeclareReductionOp reductionDecl) {
+ return &reductionDecl.getCleanupRegion();
+ });
+ if (failed(inlineOmpRegionCleanup(
+ reductionRegions, privateReductionVariables, moduleTranslation,
+ builder, "omp.reduction.cleanup")))
+ return failure();
+ } else {
+ // Process the reductions if required.
+ if (failed(createReductionsAndCleanup(
+ wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
+ privateReductionVariables, i...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/165788