[Openmp-commits] [flang] [mlir] [openmp] [MLIR][OpenMP] Add scan reduction lowering to llvm (PR #167031)
Anchu Rajendran S via Openmp-commits
openmp-commits at lists.llvm.org
Sun Jan 11 12:19:20 PST 2026
https://github.com/anchuraj updated https://github.com/llvm/llvm-project/pull/167031
>From 18e9c5c25abe6b41341fc2fa34233a6747452888 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Fri, 7 Nov 2025 16:33:30 -0600
Subject: [PATCH 1/5] [MLIR][OpenMP] Add scan reduction lowering to llvm
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 50 ++-
.../Lower/OpenMP/Todo/nested-wsloop-scan.f90 | 34 ++
.../OpenMP/Todo/wsloop-scan-collapse.f90 | 29 ++
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 383 ++++++++++++++----
.../Target/LLVMIR/openmp-reduction-scan.mlir | 130 ++++++
mlir/test/Target/LLVMIR/openmp-todo.mlir | 69 +++-
openmp/runtime/test/scan/scan.f90 | 25 ++
7 files changed, 634 insertions(+), 86 deletions(-)
create mode 100644 flang/test/Lower/OpenMP/Todo/nested-wsloop-scan.f90
create mode 100644 flang/test/Lower/OpenMP/Todo/wsloop-scan-collapse.f90
create mode 100644 mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
create mode 100644 openmp/runtime/test/scan/scan.f90
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 4381d1e9064cf..c4646f9d5c8de 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2322,12 +2322,52 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
static mlir::omp::ScanOp
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx, mlir::Location loc,
- const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+ semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+ mlir::Location loc, const ConstructQueue &queue,
+ ConstructQueue::const_iterator item) {
mlir::omp::ScanOperands clauseOps;
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
- return mlir::omp::ScanOp::create(converter.getFirOpBuilder(),
- converter.getCurrentLocation(), clauseOps);
+ mlir::omp::ScanOp scanOp = mlir::omp::ScanOp::create(
+ converter.getFirOpBuilder(), converter.getCurrentLocation(), clauseOps);
+
+  /// Scan reduction is not implemented with nested workshare loops, the
+  /// linear clause, or loop tiling.
+ mlir::omp::LoopNestOp loopNestOp =
+ scanOp->getParentOfType<mlir::omp::LoopNestOp>();
+ mlir::omp::WsloopOp wsLoopOp = scanOp->getParentOfType<mlir::omp::WsloopOp>();
+ bool isNested =
+ (loopNestOp.getNumLoops() > 1) ||
+ (wsLoopOp && (wsLoopOp->getParentOfType<mlir::omp::WsloopOp>()));
+ if (isNested)
+ TODO(loc, "Scan directive inside nested workshare loops");
+ if (wsLoopOp && !wsLoopOp.getLinearVars().empty())
+ TODO(loc, "Scan directive with linear clause");
+ if (loopNestOp.getTileSizes())
+ TODO(loc, "Scan directive with loop tiling");
+
+ // All loop indices should be loaded after the scan construct as otherwise,
+ // it would result in using the index variable across scan directive.
+ // (`Intra-iteration dependences from a statement in the structured
+ // block sequence that precede a scan directive to a statement in the
+ // structured block sequence that follows a scan directive must not exist,
+ // except for dependences for the list items specified in an inclusive or
+ // exclusive clause.`).
+ // TODO: Nested loops are not handled.
+ mlir::Region ®ion = loopNestOp->getRegion(0);
+ mlir::Value indexVal = fir::getBase(region.getArgument(0));
+ lower::pft::Evaluation *doConstructEval = eval.parentConstruct;
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ lower::pft::Evaluation *doLoop = &doConstructEval->getFirstNestedEvaluation();
+ auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
+ assert(doStmt && "Expected do loop to be in the nested evaluation");
+ const auto &loopControl =
+ std::get<std::optional<parser::LoopControl>>(doStmt->t);
+ const parser::LoopControl::Bounds *bounds =
+ std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+ mlir::Operation *storeOp =
+ setLoopVar(converter, loc, indexVal, bounds->name.thing.symbol);
+ firOpBuilder.setInsertionPointAfter(storeOp);
+ return scanOp;
}
static mlir::omp::SectionsOp
@@ -3509,7 +3549,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
loc, queue, item);
break;
case llvm::omp::Directive::OMPD_scan:
- newOp = genScanOp(converter, symTable, semaCtx, loc, queue, item);
+ newOp = genScanOp(converter, symTable, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_section:
llvm_unreachable("genOMPDispatch: OMPD_section");
diff --git a/flang/test/Lower/OpenMP/Todo/nested-wsloop-scan.f90 b/flang/test/Lower/OpenMP/Todo/nested-wsloop-scan.f90
new file mode 100644
index 0000000000000..414e2ef2d5aed
--- /dev/null
+++ b/flang/test/Lower/OpenMP/Todo/nested-wsloop-scan.f90
@@ -0,0 +1,34 @@
+! Tests scan reduction behavior when used in nested workshare loops
+
+! RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+program nested_scan_example
+ implicit none
+ integer, parameter :: n = 4, m = 5
+ integer :: a(n, m), b(n, m)
+ integer :: i, j
+ integer :: row_sum, col_sum
+
+ do i = 1, n
+ do j = 1, m
+ a(i, j) = i + j
+ end do
+ end do
+
+ !$omp parallel do reduction(inscan, +: row_sum) private(col_sum, j)
+ do i = 1, n
+ row_sum = row_sum + i
+ !$omp scan inclusive(row_sum)
+
+ col_sum = 0
+ !$omp parallel do reduction(inscan, +: col_sum)
+ do j = 1, m
+ col_sum = col_sum + a(i, j)
+ !CHECK: not yet implemented: Scan directive inside nested workshare loops
+ !$omp scan inclusive(col_sum)
+ b(i, j) = col_sum + row_sum
+ end do
+ !$omp end parallel do
+ end do
+ !$omp end parallel do
+end program nested_scan_example
diff --git a/flang/test/Lower/OpenMP/Todo/wsloop-scan-collapse.f90 b/flang/test/Lower/OpenMP/Todo/wsloop-scan-collapse.f90
new file mode 100644
index 0000000000000..b8e6e831884ab
--- /dev/null
+++ b/flang/test/Lower/OpenMP/Todo/wsloop-scan-collapse.f90
@@ -0,0 +1,29 @@
+! Tests scan reduction behavior when used with the collapse clause
+
+! RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+program nested_loop_example
+ implicit none
+ integer :: i, j, x
+ integer, parameter :: N = 100, M = 200
+ real :: A(N, M), B(N, M)
+ x = 0
+
+ do i = 1, N
+ do j = 1, M
+ A(i, j) = i * j
+ end do
+ end do
+
+ !$omp parallel do collapse(2) reduction(inscan, +:x)
+ do i = 1, N
+ do j = 1, M
+ x = x + A(i,j)
+ !CHECK: not yet implemented: Scan directive inside nested workshare loops
+ !$omp scan inclusive(x)
+ B(i, j) = x
+ end do
+ end do
+ !$omp end parallel do
+
+end program nested_loop_example
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index aad46ba094f7b..678d6934108f7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -37,6 +37,7 @@
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include <cassert>
#include <cstdint>
#include <iterator>
#include <numeric>
@@ -79,6 +80,22 @@ class OpenMPAllocaStackFrame
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
+/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
+/// insertion points for allocas of parent of the current parallel region. The
+/// insertion point is used to allocate variables to be shared by the threads
+/// executing the parallel region. Lowering of scan reduction requires declaring
+/// shared pointers to the temporary buffer to perform scan reduction.
+class OpenMPParallelAllocaStackFrame
+ : public StateStackFrameBase<OpenMPParallelAllocaStackFrame> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPParallelAllocaStackFrame)
+
+ explicit OpenMPParallelAllocaStackFrame(
+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
+ : allocaInsertPoint(allocaIP) {}
+ llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+};
+
/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
/// operation.
@@ -86,7 +103,13 @@ class OpenMPLoopInfoStackFrame
: public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
- llvm::CanonicalLoopInfo *loopInfo = nullptr;
+ /// For constructs like scan, one LoopInfo frame can contain multiple
+ /// Canonical Loops as a single openmpLoopNestOp will be split into input
+ /// loop and scan loop.
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+ llvm::ScanInfo *scanInfo;
+ llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+ new llvm::DenseMap<llvm::Value *, llvm::Type *>();
};
/// Custom error class to signal translation errors that don't need reporting,
@@ -349,6 +372,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (!op.getDependVars().empty() || op.getDependKinds())
result = todo("depend");
};
+ auto checkExclusive = [&todo](auto op, LogicalResult &result) {
+ if (!op.getExclusiveVars().empty())
+ result = todo("exclusive");
+ };
auto checkHint = [](auto op, LogicalResult &) {
if (op.getHint())
op.emitWarning("hint clause discarded");
@@ -383,9 +410,14 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (!op.getReductionVars().empty() || op.getReductionByref() ||
op.getReductionSyms())
result = todo("reduction");
- if (op.getReductionMod() &&
- op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
- result = todo("reduction with modifier");
+ if (op.getReductionMod()) {
+ if (isa<omp::WsloopOp>(op)) {
+ if (op.getReductionMod().value() == omp::ReductionModifier::task)
+ result = todo("reduction with task modifier");
+ } else {
+ result = todo("reduction with modifier");
+ }
+ }
};
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -408,6 +440,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkOrder(op, result);
})
.Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
+ .Case([&](omp::ScanOp op) { checkExclusive(op, result); })
.Case([&](omp::SectionsOp op) {
checkAllocate(op, result);
checkPrivate(op, result);
@@ -536,15 +569,59 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
/// Find the loop information structure for the loop nest being translated. It
/// will return a `null` value unless called from the translation function for
/// a loop wrapper operation after successfully translating its body.
-static llvm::CanonicalLoopInfo *
-findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
- llvm::CanonicalLoopInfo *loopInfo = nullptr;
+static SmallVector<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](OpenMPLoopInfoStackFrame &frame) {
+ loopInfos = frame.loopInfos;
+ return WalkResult::interrupt();
+ });
+ return loopInfos;
+}
+
+// LoopFrame stores the scaninfo which is used for scan reduction.
+// Upon encountering an `inscan` reduction modifier, `scanInfoInitialize`
+// initializes the ScanInfo and is used when scan directive is encountered
+// in the body of the loop nest.
+static llvm::ScanInfo *
+findScanInfo(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::ScanInfo *scanInfo = nullptr;
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](OpenMPLoopInfoStackFrame &frame) {
+ scanInfo = frame.scanInfo;
+ return WalkResult::interrupt();
+ });
+ return scanInfo;
+}
+
+// The types of reduction vars are used for lowering scan directive which
+// appears in the body of the loop. The types are stored in loop frame when
+// reduction clause is encountered and is used when scan directive is
+// encountered.
+static llvm::DenseMap<llvm::Value *, llvm::Type *> *
+findReductionVarTypes(LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType = nullptr;
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- loopInfo = frame.loopInfo;
+ reductionVarToType = frame.reductionVarToType;
return WalkResult::interrupt();
});
- return loopInfo;
+ return reductionVarToType;
+}
+
+// Scan reduction requires a shared buffer to be allocated to perform reduction.
+// ParallelAllocaStackFrame holds the allocaIP where shared allocation can be
+// done.
+static llvm::OpenMPIRBuilder::InsertPointTy
+findParallelAllocaIP(LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder::InsertPointTy parallelAllocaIP;
+ moduleTranslation.stackWalk<OpenMPParallelAllocaStackFrame>(
+ [&](OpenMPParallelAllocaStackFrame &frame) {
+ parallelAllocaIP = frame.allocaInsertPoint;
+ return WalkResult::interrupt();
+ });
+ return parallelAllocaIP;
}
/// Converts the given region that appears within an OpenMP dialect operation to
@@ -1301,11 +1378,17 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
for (auto [data, addr] : deferredStores)
builder.CreateStore(data, addr);
+ llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+ findReductionVarTypes(moduleTranslation);
// Before the loop, store the initial values of reductions into reduction
// variables. Although this could be done after allocas, we don't want to mess
// up with the alloca insertion point.
for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
SmallVector<llvm::Value *, 1> phis;
+ llvm::Type *reductionType =
+ moduleTranslation.convertType(reductionDecls[i].getType());
+ if (reductionVarToType != nullptr)
+ (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
// map block argument to initializer region
mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
@@ -1381,6 +1464,8 @@ static void collectReductionInfo(
// Collect the reduction information.
reductionInfos.reserve(numReductions);
+ llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+ findReductionVarTypes(moduleTranslation);
for (unsigned i = 0; i < numReductions; ++i) {
llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
if (owningAtomicReductionGens[i])
@@ -1397,9 +1482,12 @@ static void collectReductionInfo(
return mlir::WalkResult::advance();
});
+ llvm::Type *reductionType =
+ moduleTranslation.convertType(reductionDecls[i].getType());
+ if (reductionVarToType != nullptr)
+ (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
reductionInfos.push_back(
- {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
- privateReductionVariables[i],
+ {reductionType, variable, privateReductionVariables[i],
/*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
owningReductionGens[i],
/*ReductionGenClang=*/nullptr, atomicGen,
@@ -2645,6 +2733,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
bool isSimd = wsloopOp.getScheduleSimd();
bool loopNeedsBarrier = !wsloopOp.getNowait();
+ bool isInScanRegion =
+ wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+ mlir::omp::ReductionModifier::inscan);
// The only legal way for the direct parent to be omp.distribute is that this
// represents 'distribute parallel do'. Otherwise, this is a regular
@@ -2682,21 +2773,78 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(regionBlock, opInst)))
return failure();
- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+ findCurrentLoopInfos(moduleTranslation);
+
+ const auto &&wsloopCodeGen = [&](llvm::CanonicalLoopInfo *loopInfo,
+ bool noLoopMode, bool inputScanLoop) {
+ // Emit Initialization and Update IR for linear variables
+ if (!isInScanRegion && !wsloopOp.getLinearVars().empty()) {
+ linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+ loopInfo->getPreheader());
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+ moduleTranslation.getOpenMPBuilder()->createBarrier(
+ builder.saveIP(), llvm::omp::OMPD_barrier);
+ if (failed(handleError(afterBarrierIP, *loopOp)))
+ return failure();
+ builder.restoreIP(*afterBarrierIP);
+ linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+ loopInfo->getIndVar());
+ linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
+ }
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+ ompBuilder->applyWorkshareLoop(
+ ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+ convertToScheduleKind(schedule), chunk, isSimd,
+ scheduleMod == omp::ScheduleModifier::monotonic,
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+              workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
+
+ if (failed(handleError(wsloopIP, opInst)))
+ return failure();
+
+ // Emit finalization and in-place rewrites for linear vars.
+ if (!isInScanRegion && !wsloopOp.getLinearVars().empty()) {
+ llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+      if (!loopInfo->getLastIter())
+ return failure();
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+ linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
+ loopInfo->getLastIter());
+ if (failed(handleError(afterBarrierIP, *loopOp)))
+ return failure();
+ for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
+ linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
+ index);
+ builder.restoreIP(oldIP);
+ }
+ if (!inputScanLoop || !isInScanRegion)
+ popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
- // Emit Initialization and Update IR for linear variables
- if (!wsloopOp.getLinearVars().empty()) {
- linearClauseProcessor.initLinearVar(builder, moduleTranslation,
- loopInfo->getPreheader());
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
- moduleTranslation.getOpenMPBuilder()->createBarrier(
- builder.saveIP(), llvm::omp::OMPD_barrier);
- if (failed(handleError(afterBarrierIP, *loopOp)))
+    return success();
+ };
+
+ if (isInScanRegion) {
+ auto inputLoopFinishIp = loopInfos.front()->getAfterIP();
+ builder.restoreIP(inputLoopFinishIp);
+ SmallVector<OwningReductionGen> owningReductionGens;
+ SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+ collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
+ owningReductionGens, owningAtomicReductionGens,
+ privateReductionVariables, reductionInfos);
+ llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
+ llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
+ ompBuilder->emitScanReduction(builder.saveIP(), reductionInfos,
+ scanInfo);
+ if (failed(handleError(redIP, opInst)))
return failure();
- builder.restoreIP(*afterBarrierIP);
- linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
- loopInfo->getIndVar());
- linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
+
+ builder.restoreIP(*redIP);
+ builder.CreateBr(cont);
}
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
@@ -2721,42 +2869,37 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
}
}
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
- ompBuilder->applyWorkshareLoop(
- ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
- convertToScheduleKind(schedule), chunk, isSimd,
- scheduleMod == omp::ScheduleModifier::monotonic,
- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
-
- if (failed(handleError(wsloopIP, opInst)))
- return failure();
-
- // Emit finalization and in-place rewrites for linear vars.
- if (!wsloopOp.getLinearVars().empty()) {
- llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
- assert(loopInfo->getLastIter() &&
- "`lastiter` in CanonicalLoopInfo is nullptr");
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
- linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
- loopInfo->getLastIter());
- if (failed(handleError(afterBarrierIP, *loopOp)))
+ if (isInScanRegion)
+ assert(wsloopOp.getLinearVars().empty() &&
+ "Linear clause support is not enabled with scan reduction");
+ // For Scan loops input loop need not pop cancellation CB and hence, it is set
+ // false for the first loop
+ bool inputScanLoop = isInScanRegion;
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+ // TODO: Linear clause support needs to be enabled for scan reduction.
+ if (failed(wsloopCodeGen(loopInfo, noLoopMode, inputScanLoop)))
return failure();
- for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
- linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
- index);
- builder.restoreIP(oldIP);
+ inputScanLoop = false;
}
- // Set the correct branch target for task cancellation
- popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
-
- // Process the reductions if required.
- if (failed(createReductionsAndCleanup(
- wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
- privateReductionVariables, isByRef, wsloopOp.getNowait(),
- /*isTeamsReduction=*/false)))
- return failure();
+ if (isInScanRegion) {
+ SmallVector<Region *> reductionRegions;
+ llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
+ [](omp::DeclareReductionOp reductionDecl) {
+ return &reductionDecl.getCleanupRegion();
+ });
+ if (failed(inlineOmpRegionCleanup(
+ reductionRegions, privateReductionVariables, moduleTranslation,
+ builder, "omp.reduction.cleanup")))
+ return failure();
+ } else {
+ // Process the reductions if required.
+ if (failed(createReductionsAndCleanup(
+ wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
+ privateReductionVariables, isByRef, wsloopOp.getNowait(),
+ /*isTeamsReduction=*/false)))
+ return failure();
+ }
return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
privateVarsInfo.llvmVars,
@@ -2939,6 +3082,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ LLVM::ModuleTranslation::SaveStack<OpenMPParallelAllocaStackFrame> frame(
+ moduleTranslation, allocaIP);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
@@ -3088,7 +3233,11 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(regionBlock, opInst)))
return failure();
- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+  llvm::CanonicalLoopInfo *loopInfo = loopInfos.front();
// Emit Initialization for linear variables
if (simdOp.getLinearVars().size()) {
linearClauseProcessor.initLinearVar(builder, moduleTranslation,
@@ -3098,12 +3247,13 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
loopInfo->getIndVar());
}
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
-
- ompBuilder->applySimd(loopInfo, alignedVars,
- simdOp.getIfExpr()
- ? moduleTranslation.lookupValue(simdOp.getIfExpr())
- : nullptr,
- order, simdlen, safelen);
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+ ompBuilder->applySimd(
+ loopInfo, alignedVars,
+ simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+ : nullptr,
+ order, simdlen, safelen);
+ }
linearClauseProcessor.emitStoresForLinearVar(builder);
for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
@@ -3159,6 +3309,40 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
privateVarsInfo.privatizers);
}
+static LogicalResult
+convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ if (failed(checkImplementationStatus(opInst)))
+ return failure();
+ auto scanOp = cast<omp::ScanOp>(opInst);
+ bool isInclusive = scanOp.hasInclusiveVars();
+ SmallVector<llvm::Value *> llvmScanVars;
+ SmallVector<llvm::Type *> llvmScanVarsType;
+ mlir::OperandRange mlirScanVars = scanOp.getInclusiveVars();
+ if (!isInclusive)
+ mlirScanVars = scanOp.getExclusiveVars();
+
+ llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+ findReductionVarTypes(moduleTranslation);
+ for (auto val : mlirScanVars) {
+ llvm::Value *llvmVal = moduleTranslation.lookupValue(val);
+ llvmScanVars.push_back(llvmVal);
+ llvmScanVarsType.push_back((*reductionVarToType)[llvmVal]);
+ }
+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+ findParallelAllocaIP(moduleTranslation);
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+ moduleTranslation.getOpenMPBuilder()->createScan(
+ ompLoc, allocaIP, llvmScanVars, llvmScanVarsType, isInclusive,
+ scanInfo);
+ if (failed(handleError(afterIP, opInst)))
+ return failure();
+ builder.restoreIP(*afterIP);
+ return success();
+}
+
/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -3220,14 +3404,47 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
computeIP = loopInfos.front()->getPreheaderIP();
}
+ bool isInScanRegion = false;
+ if (auto wsloopOp = loopOp->getParentOfType<omp::WsloopOp>())
+ isInScanRegion =
+ wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+ mlir::omp::ReductionModifier::inscan);
+ if (isInScanRegion) {
+ llvm::Expected<llvm::ScanInfo *> res = ompBuilder->scanInfoInitialize();
+ if (failed(handleError(res, *loopOp)))
+ return failure();
+ llvm::ScanInfo *scanInfo = res.get();
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](OpenMPLoopInfoStackFrame &frame) {
+ frame.scanInfo = scanInfo;
+ return WalkResult::interrupt();
+ });
+ llvm::Expected<llvm::SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
+ ompBuilder->createCanonicalScanLoops(
+ loc, bodyGen, lowerBound, upperBound, step,
+ /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, "loop",
+ scanInfo);
+
+ if (failed(handleError(loopResults, *loopOp)))
+ return failure();
+ llvm::CanonicalLoopInfo *inputLoop = loopResults.get().front();
+ llvm::CanonicalLoopInfo *scanLoop = loopResults.get().back();
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](OpenMPLoopInfoStackFrame &frame) {
+ frame.loopInfos.push_back(inputLoop);
+ frame.loopInfos.push_back(scanLoop);
+ return WalkResult::interrupt();
+ });
+ builder.restoreIP(scanLoop->getAfterIP());
+ // TODO: tiling and collapse are not yet implemented for scan reduction
+ return success();
+ }
llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
ompBuilder->createCanonicalLoop(
loc, bodyGen, lowerBound, upperBound, step,
/*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
-
if (failed(handleError(loopResult, *loopOp)))
return failure();
-
loopInfos.push_back(*loopResult);
}
@@ -3270,7 +3487,7 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
assert(newTopLoopInfo && "New top loop information is missing");
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- frame.loopInfo = newTopLoopInfo;
+ frame.loopInfos.push_back(newTopLoopInfo);
return WalkResult::interrupt();
});
@@ -5388,18 +5605,21 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
bool loopNeedsBarrier = false;
llvm::Value *chunk = moduleTranslation.lookupValue(
distributeOp.getDistScheduleChunkSize());
- llvm::CanonicalLoopInfo *loopInfo =
- findCurrentLoopInfo(moduleTranslation);
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
- ompBuilder->applyWorkshareLoop(
- ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
- convertToScheduleKind(schedule), chunk, isSimd,
- scheduleMod == omp::ScheduleModifier::monotonic,
- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType, false, hasDistSchedule, chunk);
-
- if (!wsloopIP)
- return wsloopIP.takeError();
+
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+ findCurrentLoopInfos(moduleTranslation);
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+ ompBuilder->applyWorkshareLoop(
+ ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+ convertToScheduleKind(schedule), chunk, isSimd,
+ scheduleMod == omp::ScheduleModifier::monotonic,
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+ workshareLoopType);
+
+ if (!wsloopIP)
+ return wsloopIP.takeError();
+ }
}
if (failed(cleanupPrivateVars(builder, moduleTranslation,
distributeOp.getLoc(), privVarsInfo.llvmVars,
@@ -6604,6 +6824,11 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
.Case([&](omp::WsloopOp) {
return convertOmpWsloop(*op, builder, moduleTranslation);
})
+ .Case([&](omp::ScanOp) {
+ if (failed(checkImplementationStatus(*op)))
+ return failure();
+ return convertOmpScan(*op, builder, moduleTranslation);
+ })
.Case([&](omp::SimdOp) {
return convertOmpSimd(*op, builder, moduleTranslation);
})
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
new file mode 100644
index 0000000000000..ed04a069b998f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+ %0 = llvm.add %arg0, %arg1 : i32
+ omp.yield(%0 : i32)
+}
+// CHECK-LABEL: @scan_reduction
+llvm.func @scan_reduction() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x i32 {bindc_name = "z"} : (i64) -> !llvm.ptr
+ %2 = llvm.mlir.constant(1 : i64) : i64
+ %3 = llvm.alloca %2 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr
+ %4 = llvm.mlir.constant(1 : i64) : i64
+ %5 = llvm.alloca %4 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+ %6 = llvm.mlir.constant(1 : i64) : i64
+ %7 = llvm.alloca %6 x i32 {bindc_name = "k"} : (i64) -> !llvm.ptr
+ %8 = llvm.mlir.constant(0 : index) : i64
+ %9 = llvm.mlir.constant(1 : index) : i64
+ %10 = llvm.mlir.constant(100 : i32) : i32
+ %11 = llvm.mlir.constant(1 : i32) : i32
+ %12 = llvm.mlir.constant(0 : i32) : i32
+ %13 = llvm.mlir.constant(100 : index) : i64
+ %14 = llvm.mlir.addressof @_QFEa : !llvm.ptr
+ %15 = llvm.mlir.addressof @_QFEb : !llvm.ptr
+ omp.parallel {
+ %37 = llvm.mlir.constant(1 : i64) : i64
+ %38 = llvm.alloca %37 x i32 {bindc_name = "k", pinned} : (i64) -> !llvm.ptr
+ %39 = llvm.mlir.constant(1 : i64) : i64
+ omp.wsloop reduction(mod: inscan, @add_reduction_i32 %5 -> %arg0 : !llvm.ptr) {
+ omp.loop_nest (%arg1) : i32 = (%11) to (%10) inclusive step (%11) {
+ llvm.store %arg1, %38 : i32, !llvm.ptr
+ %40 = llvm.load %arg0 : !llvm.ptr -> i32
+ %41 = llvm.load %38 : !llvm.ptr -> i32
+ %42 = llvm.sext %41 : i32 to i64
+ %50 = llvm.getelementptr %14[%42] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+ %51 = llvm.load %50 : !llvm.ptr -> i32
+ %52 = llvm.add %40, %51 : i32
+ llvm.store %52, %arg0 : i32, !llvm.ptr
+ omp.scan inclusive(%arg0 : !llvm.ptr)
+ llvm.store %arg1, %38 : i32, !llvm.ptr
+ %53 = llvm.load %arg0 : !llvm.ptr -> i32
+ %54 = llvm.load %38 : !llvm.ptr -> i32
+ %55 = llvm.sext %54 : i32 to i64
+ %63 = llvm.getelementptr %15[%55] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+ llvm.store %53, %63 : i32, !llvm.ptr
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ llvm.return
+}
+llvm.mlir.global internal @_QFEa() {addr_space = 0 : i32} : !llvm.array<100 x i32> {
+ %0 = llvm.mlir.zero : !llvm.array<100 x i32>
+ llvm.return %0 : !llvm.array<100 x i32>
+}
+llvm.mlir.global internal @_QFEb() {addr_space = 0 : i32} : !llvm.array<100 x i32> {
+ %0 = llvm.mlir.zero : !llvm.array<100 x i32>
+ llvm.return %0 : !llvm.array<100 x i32>
+}
+llvm.mlir.global internal constant @_QFECn() {addr_space = 0 : i32} : i32 {
+ %0 = llvm.mlir.constant(100 : i32) : i32
+ llvm.return %0 : i32
+}
+//CHECK: %vla = alloca ptr, align 8
+//CHECK: omp_parallel
+//CHECK: store ptr %vla, ptr %gep_vla, align 8
+//CHECK: @__kmpc_fork_call
+//CHECK: void @scan_reduction..omp_par
+//CHECK: %[[BUFF_PTR:.+]] = load ptr, ptr %gep_vla
+//CHECK: @__kmpc_masked
+//CHECK: @__kmpc_barrier
+//CHECK: @__kmpc_masked
+//CHECK: @__kmpc_barrier
+//CHECK: omp.scan.loop.cont:
+//CHECK: @__kmpc_masked
+//CHECK: @__kmpc_barrier
+//CHECK: %[[FREE_VAR:.+]] = load ptr, ptr %[[BUFF_PTR]], align 8
+//CHECK: %[[ARRLAST:.+]] = getelementptr inbounds i32, ptr %[[FREE_VAR]], i32 100
+//CHECK: %[[RES:.+]] = load i32, ptr %[[ARRLAST]], align 4
+//CHECK: store i32 %[[RES]], ptr %loadgep{{.*}}, align 4
+//CHECK: tail call void @free(ptr %[[FREE_VAR]])
+//CHECK: @__kmpc_end_masked
+//CHECK: omp.inscan.dispatch{{.*}}: ; preds = %omp_loop.body{{.*}}
+//CHECK: %[[BUFFVAR:.+]] = load ptr, ptr %[[BUFF_PTR]], align 8
+//CHECK: %[[arrayOffset1:.+]] = getelementptr inbounds i32, ptr %[[BUFFVAR]], i32 %{{.*}}
+//CHECK: %[[BUFFVAL1:.+]] = load i32, ptr %[[arrayOffset1]], align 4
+//CHECK: store i32 %[[BUFFVAL1]], ptr %{{.*}}, align 4
+//CHECK: %[[LOG:.+]] = call double @llvm.log2.f64(double 1.000000e+02) #0
+//CHECK: %[[CEIL:.+]] = call double @llvm.ceil.f64(double %[[LOG]]) #0
+//CHECK: %[[UB:.+]] = fptoui double %[[CEIL]] to i32
+//CHECK: br label %omp.outer.log.scan.body
+//CHECK: omp.outer.log.scan.body:
+//CHECK: %[[K:.+]] = phi i32 [ 0, %{{.*}} ], [ %[[NEXTK:.+]], %omp.inner.log.scan.exit ]
+//CHECK: %[[I:.+]] = phi i32 [ 1, %{{.*}} ], [ %[[NEXTI:.+]], %omp.inner.log.scan.exit ]
+//CHECK: %[[CMP1:.+]] = icmp uge i32 99, %[[I]]
+//CHECK: br i1 %[[CMP1]], label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit
+//CHECK: omp.inner.log.scan.exit: ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body
+//CHECK: %[[NEXTK]] = add nuw i32 %[[K]], 1
+//CHECK: %[[NEXTI]] = shl nuw i32 %[[I]], 1
+//CHECK: %[[CMP2:.+]] = icmp ne i32 %[[NEXTK]], %[[UB]]
+//CHECK: br i1 %[[CMP2]], label %omp.outer.log.scan.body, label %omp.outer.log.scan.exit
+//CHECK: omp.outer.log.scan.exit: ; preds = %omp.inner.log.scan.exit
+//CHECK: @__kmpc_end_masked
+//CHECK: omp.inner.log.scan.body: ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body
+//CHECK: %[[CNT:.+]] = phi i32 [ 99, %omp.outer.log.scan.body ], [ %[[CNTNXT:.+]], %omp.inner.log.scan.body ]
+//CHECK: %[[BUFF:.+]] = load ptr, ptr %[[BUFF_PTR]]
+//CHECK: %[[IND1:.+]] = add i32 %[[CNT]], 1
+//CHECK: %[[IND1PTR:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %[[IND1]]
+//CHECK: %[[IND2:.+]] = sub nuw i32 %[[IND1]], %[[I]]
+//CHECK: %[[IND2PTR:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %[[IND2]]
+//CHECK: %[[IND1VAL:.+]] = load i32, ptr %[[IND1PTR]], align 4
+//CHECK: %[[IND2VAL:.+]] = load i32, ptr %[[IND2PTR]], align 4
+//CHECK: %[[REDVAL:.+]] = add i32 %[[IND1VAL]], %[[IND2VAL]]
+//CHECK: store i32 %[[REDVAL]], ptr %[[IND1PTR]], align 4
+//CHECK: %[[CNTNXT]] = sub nuw i32 %[[CNT]], 1
+//CHECK: %[[CMP3:.+]] = icmp uge i32 %[[CNTNXT]], %[[I]]
+//CHECK: br i1 %[[CMP3]], label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit
+//CHECK: omp.inscan.dispatch: ; preds = %omp_loop.body
+//CHECK: br i1 true, label %omp.before.scan.bb, label %omp.after.scan.bb
+//CHECK: omp.loop_nest.region: ; preds = %omp.before.scan.bb
+//CHECK: %[[BUFFER:.+]] = load ptr, ptr %loadgep_vla, align 8
+//CHECK: %[[ARRAYOFFSET2:.+]] = getelementptr inbounds i32, ptr %[[BUFFER]], i32 %{{.*}}
+//CHECK-NEXT: %[[REDPRIVVAL:.+]] = load i32, ptr %{{.*}}, align 4
+//CHECK: store i32 %[[REDPRIVVAL]], ptr %[[ARRAYOFFSET2]], align 4
+//CHECK: br label %omp.scan.loop.exit
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index d4cc9e215de1d..81119ef700c06 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -101,6 +101,68 @@ llvm.func @sections_private(%x : !llvm.ptr) {
}
+// -----
+
+omp.declare_reduction @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = llvm.mlir.constant(0.0 : f32) : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = llvm.fadd %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+ %2 = llvm.load %arg3 : !llvm.ptr -> f32
+ llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
+ omp.yield
+}
+llvm.func @task_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+ // expected-error at below {{LLVM Translation failed for operation: omp.wsloop}}
+ // expected-error at below {{not yet implemented: Unhandled clause reduction with task modifier in omp.wsloop operation}}
+ omp.wsloop reduction(mod:task, @add_f32 %x -> %prv : !llvm.ptr) {
+ omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ }
+ llvm.return
+}
+
+
+// -----
+
+omp.declare_reduction @add_f32 : f32
+init {
+^bb0(%arg: f32):
+ %0 = llvm.mlir.constant(0.0 : f32) : f32
+ omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+ %1 = llvm.fadd %arg0, %arg1 : f32
+ omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+ %2 = llvm.load %arg3 : !llvm.ptr -> f32
+ llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
+ omp.yield
+}
+llvm.func @simd_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+ // expected-error at below {{not yet implemented: Unhandled clause reduction with modifier in omp.simd operation}}
+ // expected-error at below {{LLVM Translation failed for operation: omp.simd}}
+ omp.simd reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) {
+ omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ omp.scan inclusive(%prv : !llvm.ptr)
+ omp.yield
+ }
+ }
+ llvm.return
+}
+
// -----
omp.declare_reduction @add_f32 : f32
@@ -121,17 +183,20 @@ atomic {
omp.yield
}
llvm.func @scan_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
- // expected-error at below {{not yet implemented: Unhandled clause reduction with modifier in omp.wsloop operation}}
// expected-error at below {{LLVM Translation failed for operation: omp.wsloop}}
omp.wsloop reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) {
+ // expected-error at below {{LLVM Translation failed for operation: omp.loop_nest}}
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
- omp.scan inclusive(%prv : !llvm.ptr)
+ // expected-error at below {{not yet implemented: Unhandled clause exclusive in omp.scan operation}}
+ // expected-error at below {{LLVM Translation failed for operation: omp.scan}}
+ omp.scan exclusive(%prv : !llvm.ptr)
omp.yield
}
}
llvm.return
}
+
// -----
llvm.func @single_allocate(%x : !llvm.ptr) {
diff --git a/openmp/runtime/test/scan/scan.f90 b/openmp/runtime/test/scan/scan.f90
new file mode 100644
index 0000000000000..1bea5c7e7a481
--- /dev/null
+++ b/openmp/runtime/test/scan/scan.f90
@@ -0,0 +1,25 @@
+! RUN: %flang %flags %openmp_flags -fopenmp-version=51 %s -o %t.exe
+! RUN: %t.exe | FileCheck %s --match-full-lines
+program inclusive_scan
+ implicit none
+ integer, parameter :: n = 100
+ integer a(n), b(n)
+ integer x, k, y, z
+
+ ! initialization
+ x = 0
+ do k = 1, n
+ a(k) = k
+ end do
+
+ ! a(k) contributes to the running prefix sum whose value is stored in b(k)
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+ x = x + a(k)
+ !$omp scan inclusive(x)
+ b(k) = x
+ end do
+
+ print *,'x =', x
+end program
+!CHECK: x = 5050
>From a421469abfb770693e93621fc066ba17a4c56ec7 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Thu, 13 Nov 2025 15:25:44 -0600
Subject: [PATCH 2/5] R2: Added unique_ptr for reduction var types
---
.../LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 678d6934108f7..778936ab90275 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -108,8 +108,10 @@ class OpenMPLoopInfoStackFrame
/// loop and scan loop.
SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
llvm::ScanInfo *scanInfo;
- llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
- new llvm::DenseMap<llvm::Value *, llvm::Type *>();
+ /// Map reduction variables to their LLVM types.
+ std::unique_ptr<llvm::DenseMap<llvm::Value *, llvm::Type *>>
+ reductionVarToType =
+ std::make_unique<llvm::DenseMap<llvm::Value *, llvm::Type *>>();
};
/// Custom error class to signal translation errors that don't need reporting,
@@ -604,7 +606,7 @@ findReductionVarTypes(LLVM::ModuleTranslation &moduleTranslation) {
llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType = nullptr;
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- reductionVarToType = frame.reductionVarToType;
+ reductionVarToType = frame.reductionVarToType.get();
return WalkResult::interrupt();
});
return reductionVarToType;
>From f41499be3274a8e1d95e044d67b2f8b0546a1ac1 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Wed, 17 Dec 2025 13:59:54 -0600
Subject: [PATCH 3/5] Addressing a few comments
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 10 ++++++++--
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 18 +++++++++++-------
2 files changed, 19 insertions(+), 9 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c4646f9d5c8de..97f304206dd74 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2330,11 +2330,17 @@ genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::omp::ScanOp scanOp = mlir::omp::ScanOp::create(
converter.getFirOpBuilder(), converter.getCurrentLocation(), clauseOps);
- /// Scan redution is not implemented with nested workshare loops, linear
+ /// Scan reduction is not implemented with nested workshare loops, linear
/// clause, tiling
mlir::omp::LoopNestOp loopNestOp =
scanOp->getParentOfType<mlir::omp::LoopNestOp>();
- mlir::omp::WsloopOp wsLoopOp = scanOp->getParentOfType<mlir::omp::WsloopOp>();
+ llvm::SmallVector<mlir::omp::LoopWrapperInterface> loopWrappers;
+ loopNestOp.gatherWrappers(loopWrappers);
+ mlir::Operation *loopWrapperOp = loopWrappers.front().getOperation();
+ if (llvm::isa<mlir::omp::SimdOp>(loopWrapperOp)) TODO(loc, "unsupported simd");
+ if (loopWrappers.size() > 1) TODO(loc, "unsupported composite");
+ mlir::omp::WsloopOp wsLoopOp = llvm::cast<mlir::omp::WsloopOp>(loopWrapperOp);
+ //mlir::omp::WsloopOp wsLoopOp = scanOp->getParentOfType<mlir::omp::WsloopOp>();
bool isNested =
(loopNestOp.getNumLoops() > 1) ||
(wsLoopOp && (wsLoopOp->getParentOfType<mlir::omp::WsloopOp>()));
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 778936ab90275..84eab0b31a219 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -75,9 +75,13 @@ class OpenMPAllocaStackFrame
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
- explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
- : allocaInsertPoint(allocaIP) {}
+ explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP, bool parallelOp = false)
+ : allocaInsertPoint(allocaIP), containsParallelOp(parallelOp) {}
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+ // is set to true when a parallel Op is encountered.
+ // The alloca IP of a function where a parallel Op is defined may
+ // be used for the scan directive.
+ bool containsParallelOp = false;
};
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
@@ -2738,6 +2742,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
bool isInScanRegion =
wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
mlir::omp::ReductionModifier::inscan);
+ if (isInScanRegion)
+ assert(wsloopOp.getLinearVars().empty() &&
+ "Linear clause support is not enabled with scan reduction");
// The only legal way for the direct parent to be omp.distribute is that this
// represents 'distribute parallel do'. Otherwise, this is a regular
@@ -2781,7 +2788,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
const auto &&wsloopCodeGen = [&](llvm::CanonicalLoopInfo *loopInfo,
bool noLoopMode, bool inputScanLoop) {
// Emit Initialization and Update IR for linear variables
- if (!isInScanRegion && !wsloopOp.getLinearVars().empty()) {
+ if (!wsloopOp.getLinearVars().empty()) {
linearClauseProcessor.initLinearVar(builder, moduleTranslation,
loopInfo->getPreheader());
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
@@ -2807,7 +2814,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
return failure();
// Emit finalization and in-place rewrites for linear vars.
- if (!isInScanRegion && !wsloopOp.getLinearVars().empty()) {
+ if (!wsloopOp.getLinearVars().empty()) {
llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
if (loopInfo->getLastIter())
return failure();
@@ -2871,9 +2878,6 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
}
}
- if (isInScanRegion)
- assert(wsloopOp.getLinearVars().empty() &&
- "Linear clause support is not enabled with scan reduction");
// For Scan loops input loop need not pop cancellation CB and hence, it is set
// false for the first loop
bool inputScanLoop = isInScanRegion;
>From 0290fa9b8c570455edae77f5bf43f79f947d27a5 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Fri, 2 Jan 2026 17:57:17 -0600
Subject: [PATCH 4/5] adding parallel alloca ip in same stack frame
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 9 +-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 99 +++++++++++--------
.../Target/LLVMIR/openmp-reduction-scan.mlir | 15 +--
openmp/runtime/test/scan/scan.f90 | 13 +++
4 files changed, 82 insertions(+), 54 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 97f304206dd74..e77675a4c55f2 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2337,10 +2337,13 @@ genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
llvm::SmallVector<mlir::omp::LoopWrapperInterface> loopWrappers;
loopNestOp.gatherWrappers(loopWrappers);
mlir::Operation *loopWrapperOp = loopWrappers.front().getOperation();
- if (llvm::isa<mlir::omp::SimdOp>(loopWrapperOp)) TODO(loc, "unsupported simd");
- if (loopWrappers.size() > 1) TODO(loc, "unsupported composite");
+ if (llvm::isa<mlir::omp::SimdOp>(loopWrapperOp))
+ TODO(loc, "unsupported simd");
+ if (loopWrappers.size() > 1)
+ TODO(loc, "unsupported composite");
mlir::omp::WsloopOp wsLoopOp = llvm::cast<mlir::omp::WsloopOp>(loopWrapperOp);
- //mlir::omp::WsloopOp wsLoopOp = scanOp->getParentOfType<mlir::omp::WsloopOp>();
+ // mlir::omp::WsloopOp wsLoopOp =
+ // scanOp->getParentOfType<mlir::omp::WsloopOp>();
bool isNested =
(loopNestOp.getNumLoops() > 1) ||
(wsLoopOp && (wsLoopOp->getParentOfType<mlir::omp::WsloopOp>()));
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 84eab0b31a219..80cfc21364b34 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -38,6 +38,7 @@
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <cassert>
+#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
@@ -75,31 +76,16 @@ class OpenMPAllocaStackFrame
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
- explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP, bool parallelOp = false)
+ explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
+ bool parallelOp = false)
: allocaInsertPoint(allocaIP), containsParallelOp(parallelOp) {}
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
// is set to true when a parallel Op is encountered.
- // The alloca IP of a function where a parallel Op is defined may
+ // The alloca IP of a function where a parallel Op is defined may
// be used for the scan directive.
bool containsParallelOp = false;
};
-/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
-/// insertion points for allocas of parent of the current parallel region. The
-/// insertion point is used to allocate variables to be shared by the threads
-/// executing the parallel region. Lowering of scan reduction requires declaring
-/// shared pointers to the temporary buffer to perform scan reduction.
-class OpenMPParallelAllocaStackFrame
- : public StateStackFrameBase<OpenMPParallelAllocaStackFrame> {
-public:
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPParallelAllocaStackFrame)
-
- explicit OpenMPParallelAllocaStackFrame(
- llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
- : allocaInsertPoint(allocaIP) {}
- llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
-};
-
/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
/// operation.
@@ -111,7 +97,7 @@ class OpenMPLoopInfoStackFrame
/// Canonical Loops as a single openmpLoopNestOp will be split into input
/// loop and scan loop.
SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
- llvm::ScanInfo *scanInfo;
+ llvm::ScanInfo *scanInfo = nullptr;
/// Map reduction variables to their LLVM types.
std::unique_ptr<llvm::DenseMap<llvm::Value *, llvm::Type *>>
reductionVarToType =
@@ -420,8 +406,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (isa<omp::WsloopOp>(op)) {
if (op.getReductionMod().value() == omp::ReductionModifier::task)
result = todo("reduction with task modifier");
- } else {
- result = todo("reduction with modifier");
}
}
};
@@ -617,17 +601,47 @@ findReductionVarTypes(LLVM::ModuleTranslation &moduleTranslation) {
}
// Scan reduction requires a shared buffer to be allocated to perform reduction.
-// ParallelAllocaStackFrame holds the allocaIP where shared allocation can be
-// done.
+// The allocation needs to be done outside the parallel region in which the
+// scan operation is used.
static llvm::OpenMPIRBuilder::InsertPointTy
-findParallelAllocaIP(LLVM::ModuleTranslation &moduleTranslation) {
- llvm::OpenMPIRBuilder::InsertPointTy parallelAllocaIP;
- moduleTranslation.stackWalk<OpenMPParallelAllocaStackFrame>(
- [&](OpenMPParallelAllocaStackFrame &frame) {
- parallelAllocaIP = frame.allocaInsertPoint;
- return WalkResult::interrupt();
+findParallelAllocaIP(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ // If there is an alloca insertion point on stack, i.e. we are in a nested
+ // operation and a specific point was provided by some surrounding operation,
+ // use it.
+ llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+ WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
+ [&](OpenMPAllocaStackFrame &frame) {
+ if (frame.containsParallelOp) {
+ allocaInsertPoint = frame.allocaInsertPoint;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::skip();
});
- return parallelAllocaIP;
+ if (walkResult.wasInterrupted())
+ return allocaInsertPoint;
+ // Otherwise, insert to the entry block of the surrounding function.
+ // If the current IRBuilder InsertPoint is the function's entry, it cannot
+ // also be used for alloca insertion which would result in insertion order
+ // confusion. Create a new BasicBlock for the Builder and use the entry block
+ // for the allocs.
+ // TODO: Create a dedicated alloca BasicBlock at function creation such that
+ // we do not need to move the current InsertPoint here.
+ if (builder.GetInsertBlock() ==
+ &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
+ assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
+ "Assuming end of basic block");
+ llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
+ builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
+ builder.GetInsertBlock()->getNextNode());
+ builder.CreateBr(entryBB);
+ builder.SetInsertPoint(entryBB);
+ }
+
+ llvm::BasicBlock &funcEntryBlock =
+ builder.GetInsertBlock()->getParent()->getEntryBlock();
+ return llvm::OpenMPIRBuilder::InsertPointTy(
+ &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}
/// Converts the given region that appears within an OpenMP dialect operation to
@@ -2808,7 +2822,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
convertToScheduleKind(schedule), chunk, isSimd,
scheduleMod == omp::ScheduleModifier::monotonic,
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk));
+ workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
if (failed(handleError(wsloopIP, opInst)))
return failure();
@@ -2840,9 +2854,12 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
SmallVector<OwningReductionGen> owningReductionGens;
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+ SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
+ ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
owningReductionGens, owningAtomicReductionGens,
- privateReductionVariables, reductionInfos);
+ owningReductionGenRefDataPtrGens,
+ privateReductionVariables, reductionInfos, isByRef);
llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
@@ -2853,7 +2870,6 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
builder.restoreIP(*redIP);
builder.CreateBr(cont);
->>>>>>> e715dc85eb45 ([MLIR][OpenMP] Add scan reduction lowering to llvm)
}
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
@@ -2934,6 +2950,15 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
opInst.getNumReductionVars());
SmallVector<DeferredStore> deferredStores;
+ bool foundParallelOp = false;
+ moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
+ [&](OpenMPAllocaStackFrame &frame) {
+ if (foundParallelOp) {
+ frame.containsParallelOp = true;
+ return WalkResult::interrupt();
+ }
+ foundParallelOp = true;
+ });
auto bodyGenCB = [&](InsertPointTy allocaIP,
InsertPointTy codeGenIP) -> llvm::Error {
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
@@ -3088,8 +3113,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
- LLVM::ModuleTranslation::SaveStack<OpenMPParallelAllocaStackFrame> frame(
- moduleTranslation, allocaIP);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
@@ -3241,9 +3264,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
findCurrentLoopInfos(moduleTranslation);
- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation).front();
- SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
- findCurrentLoopInfos(moduleTranslation);
+ llvm::CanonicalLoopInfo *loopInfo = loopInfos.front();
// Emit Initialization for linear variables
if (simdOp.getLinearVars().size()) {
linearClauseProcessor.initLinearVar(builder, moduleTranslation,
@@ -3336,7 +3357,7 @@ convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder,
llvmScanVarsType.push_back((*reductionVarToType)[llvmVal]);
}
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
- findParallelAllocaIP(moduleTranslation);
+ findParallelAllocaIP(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
index ed04a069b998f..72211373debf2 100644
--- a/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
@@ -12,14 +12,9 @@ omp.declare_reduction @add_reduction_i32 : i32 init {
llvm.func @scan_reduction() {
%0 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %0 x i32 {bindc_name = "z"} : (i64) -> !llvm.ptr
- %2 = llvm.mlir.constant(1 : i64) : i64
- %3 = llvm.alloca %2 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr
- %4 = llvm.mlir.constant(1 : i64) : i64
- %5 = llvm.alloca %4 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
- %6 = llvm.mlir.constant(1 : i64) : i64
- %7 = llvm.alloca %6 x i32 {bindc_name = "k"} : (i64) -> !llvm.ptr
- %8 = llvm.mlir.constant(0 : index) : i64
- %9 = llvm.mlir.constant(1 : index) : i64
+ %3 = llvm.alloca %0 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr
+ %5 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+ %7 = llvm.alloca %0 x i32 {bindc_name = "k"} : (i64) -> !llvm.ptr
%10 = llvm.mlir.constant(100 : i32) : i32
%11 = llvm.mlir.constant(1 : i32) : i32
%12 = llvm.mlir.constant(0 : i32) : i32
@@ -62,10 +57,6 @@ llvm.mlir.global internal @_QFEb() {addr_space = 0 : i32} : !llvm.array<100 x i3
%0 = llvm.mlir.zero : !llvm.array<100 x i32>
llvm.return %0 : !llvm.array<100 x i32>
}
-llvm.mlir.global internal constant @_QFECn() {addr_space = 0 : i32} : i32 {
- %0 = llvm.mlir.constant(100 : i32) : i32
- llvm.return %0 : i32
-}
//CHECK: %vla = alloca ptr, align 8
//CHECK: omp_parallel
//CHECK: store ptr %vla, ptr %gep_vla, align 8
diff --git a/openmp/runtime/test/scan/scan.f90 b/openmp/runtime/test/scan/scan.f90
index 1bea5c7e7a481..76263508105d0 100644
--- a/openmp/runtime/test/scan/scan.f90
+++ b/openmp/runtime/test/scan/scan.f90
@@ -21,5 +21,18 @@ program inclusive_scan
end do
print *,'x =', x
+ do k = 1, 10
+ print *, 'b(', k, ') =', b(k)
+ end do
end program
!CHECK: x = 5050
+!CHECK: b( 1 ) = 1
+!CHECK: b( 2 ) = 3
+!CHECK: b( 3 ) = 6
+!CHECK: b( 4 ) = 10
+!CHECK: b( 5 ) = 15
+!CHECK: b( 6 ) = 21
+!CHECK: b( 7 ) = 28
+!CHECK: b( 8 ) = 36
+!CHECK: b( 9 ) = 45
+!CHECK: b( 10 ) = 55
>From b10e5bb15e86bdd5591d99db0a1443f5750188ac Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Wed, 7 Jan 2026 13:38:24 -0600
Subject: [PATCH 5/5] Removing the vector to represent loopInfo
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 109 ++++++++++--------
1 file changed, 62 insertions(+), 47 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 80cfc21364b34..708eee6959f8b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -95,8 +95,10 @@ class OpenMPLoopInfoStackFrame
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
/// For constructs like scan, one LoopInfo frame can contain multiple
/// Canonical Loops as a single openmpLoopNestOp will be split into input
- /// loop and scan loop.
- SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+ /// loop and scan loop. In the scan case, loopInfo holds the input loop
+ /// info and scanloopInfo holds the scan-loop info.
+ llvm::CanonicalLoopInfo *loopInfo = nullptr;
+ llvm::CanonicalLoopInfo *scanloopInfo = nullptr;
llvm::ScanInfo *scanInfo = nullptr;
/// Map reduction variables to their LLVM types.
std::unique_ptr<llvm::DenseMap<llvm::Value *, llvm::Type *>>
@@ -402,10 +404,13 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (!op.getReductionVars().empty() || op.getReductionByref() ||
op.getReductionSyms())
result = todo("reduction");
- if (op.getReductionMod()) {
+ if (op.getReductionMod() &&
+ op.getReductionMod().value() != omp::ReductionModifier::defaultmod) {
if (isa<omp::WsloopOp>(op)) {
if (op.getReductionMod().value() == omp::ReductionModifier::task)
result = todo("reduction with task modifier");
+ } else {
+ result = todo("reduction with modifier");
}
}
};
@@ -559,15 +564,30 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
/// Find the loop information structure for the loop nest being translated. It
/// will return a `null` value unless called from the translation function for
/// a loop wrapper operation after successfully translating its body.
-static SmallVector<llvm::CanonicalLoopInfo *>
-findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
- SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+static llvm::CanonicalLoopInfo *
+findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::CanonicalLoopInfo *loopInfo;
+ moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+ [&](OpenMPLoopInfoStackFrame &frame) {
+ loopInfo = frame.loopInfo;
+ return WalkResult::interrupt();
+ });
+ return loopInfo;
+}
+
+/// Find the scan loop information structure for the scan loop nest being
+/// translated. It will return a `null` value unless called from the translation
+/// function for a loop wrapper operation after successfully translating its
+/// body.
+static llvm::CanonicalLoopInfo *
+findCurrentScanLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::CanonicalLoopInfo *scanLoopInfo;
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- loopInfos = frame.loopInfos;
+ scanLoopInfo = frame.scanloopInfo;
return WalkResult::interrupt();
});
- return loopInfos;
+ return scanLoopInfo;
}
// The loop frame stores the ScanInfo that is used for scan reduction.
@@ -2796,9 +2816,6 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(regionBlock, opInst)))
return failure();
- SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
- findCurrentLoopInfos(moduleTranslation);
-
const auto &&wsloopCodeGen = [&](llvm::CanonicalLoopInfo *loopInfo,
bool noLoopMode, bool inputScanLoop) {
// Emit Initialization and Update IR for linear variables
@@ -2830,7 +2847,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
// Emit finalization and in-place rewrites for linear vars.
if (!wsloopOp.getLinearVars().empty()) {
llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
- if (loopInfo->getLastIter())
+ if (!loopInfo->getLastIter())
return failure();
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
@@ -2848,12 +2865,13 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
return llvm::success();
};
+ llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
if (isInScanRegion) {
- auto inputLoopFinishIp = loopInfos.front()->getAfterIP();
+ auto inputLoopFinishIp = loopInfo->getAfterIP();
builder.restoreIP(inputLoopFinishIp);
SmallVector<OwningReductionGen> owningReductionGens;
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
- SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo, 2> reductionInfos;
SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
@@ -2897,14 +2915,16 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
// For scan loops, the input loop need not pop the cancellation CB; hence,
// it is set to false for the first loop.
bool inputScanLoop = isInScanRegion;
- for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
- // TODO: Linear clause support needs to be enabled for scan reduction.
- if (failed(wsloopCodeGen(loopInfo, noLoopMode, inputScanLoop)))
- return failure();
- inputScanLoop = false;
- }
+ // TODO: Linear clause support needs to be enabled for scan reduction.
+ if (failed(wsloopCodeGen(loopInfo, noLoopMode, inputScanLoop)))
+ return failure();
+ inputScanLoop = false;
if (isInScanRegion) {
+ llvm::CanonicalLoopInfo *scanLoopInfo =
+ findCurrentScanLoopInfo(moduleTranslation);
+ if (failed(wsloopCodeGen(scanLoopInfo, noLoopMode, inputScanLoop)))
+ return failure();
SmallVector<Region *> reductionRegions;
llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
[](omp::DeclareReductionOp reductionDecl) {
@@ -2958,6 +2978,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
return WalkResult::interrupt();
}
foundParallelOp = true;
+ return WalkResult::skip();
});
auto bodyGenCB = [&](InsertPointTy allocaIP,
InsertPointTy codeGenIP) -> llvm::Error {
@@ -3262,9 +3283,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(regionBlock, opInst)))
return failure();
- SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
- findCurrentLoopInfos(moduleTranslation);
- llvm::CanonicalLoopInfo *loopInfo = loopInfos.front();
+ llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
// Emit Initialization for linear variables
if (simdOp.getLinearVars().size()) {
linearClauseProcessor.initLinearVar(builder, moduleTranslation,
@@ -3274,13 +3293,11 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
loopInfo->getIndVar());
}
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
- for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
- ompBuilder->applySimd(
- loopInfo, alignedVars,
- simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
- : nullptr,
- order, simdlen, safelen);
- }
+ ompBuilder->applySimd(loopInfo, alignedVars,
+ simdOp.getIfExpr()
+ ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+ : nullptr,
+ order, simdlen, safelen);
linearClauseProcessor.emitStoresForLinearVar(builder);
for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
@@ -3458,8 +3475,8 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::CanonicalLoopInfo *scanLoop = loopResults.get().back();
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- frame.loopInfos.push_back(inputLoop);
- frame.loopInfos.push_back(scanLoop);
+ frame.loopInfo = inputLoop;
+ frame.scanloopInfo = scanLoop;
return WalkResult::interrupt();
});
builder.restoreIP(scanLoop->getAfterIP());
@@ -3514,7 +3531,7 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
assert(newTopLoopInfo && "New top loop information is missing");
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- frame.loopInfos.push_back(newTopLoopInfo);
+ frame.loopInfo = newTopLoopInfo;
return WalkResult::interrupt();
});
@@ -5633,20 +5650,18 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::Value *chunk = moduleTranslation.lookupValue(
distributeOp.getDistScheduleChunkSize());
- SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
- findCurrentLoopInfos(moduleTranslation);
- for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
- ompBuilder->applyWorkshareLoop(
- ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
- convertToScheduleKind(schedule), chunk, isSimd,
- scheduleMod == omp::ScheduleModifier::monotonic,
- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType);
-
- if (!wsloopIP)
- return wsloopIP.takeError();
- }
+ llvm::CanonicalLoopInfo *loopInfo =
+ findCurrentLoopInfo(moduleTranslation);
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+ ompBuilder->applyWorkshareLoop(
+ ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+ convertToScheduleKind(schedule), chunk, isSimd,
+ scheduleMod == omp::ScheduleModifier::monotonic,
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+ workshareLoopType, false, hasDistSchedule, chunk);
+
+ if (!wsloopIP)
+ return wsloopIP.takeError();
}
if (failed(cleanupPrivateVars(builder, moduleTranslation,
distributeOp.getLoc(), privVarsInfo.llvmVars,
More information about the Openmp-commits
mailing list