[flang-commits] [flang] [mlir] WIP : Delayed privatisation of index variables (PR #75836)
Kiran Chandramohan via flang-commits
flang-commits at lists.llvm.org
Mon Dec 18 09:56:20 PST 2023
https://github.com/kiranchandramohan created https://github.com/llvm/llvm-project/pull/75836
Not for merge.
This patch is a WIP to delay privatisation of index variables.
>From f2089f8e44963dae9e8af47612f13e02b1eb168c Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Wed, 25 Oct 2023 18:19:08 -0500
Subject: [PATCH 1/2] [Flang] [OpenMP] [Semantics] [MLIR] [Lowering] Add
lowering support for IS_DEVICE_PTR and HAS_DEVICE_ADDR clauses on OMP TARGET
directive.
---
flang/lib/Lower/OpenMP.cpp | 54 +++++++++++++++++--
flang/test/Lower/OpenMP/FIR/target.f90 | 41 +++++++++++++-
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 18 +++++--
mlir/test/Dialect/OpenMP/ops.mlir | 8 +--
4 files changed, 110 insertions(+), 11 deletions(-)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index eeba87fcd15116..f0d6a3f382e5ff 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -604,6 +604,18 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&useDeviceSymbols) const;
+ bool
+ processIsDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ &isDeviceSymbols) const;
+ bool
+ processHasDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ &isDeviceSymbols) const;
// Call this method for these clauses that should be supported but are not
// implemented yet. It triggers a compilation error if any of the given
@@ -1890,6 +1902,34 @@ bool ClauseProcessor::processUseDevicePtr(
});
}
+bool ClauseProcessor::processIsDevicePtr(
+ llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
+ const {
+ return findRepeatableClause<ClauseTy::IsDevicePtr>(
+ [&](const ClauseTy::IsDevicePtr *devPtrClause,
+ const Fortran::parser::CharBlock &) {
+ addUseDeviceClause(converter, devPtrClause->v, operands, isDeviceTypes,
+ isDeviceLocs, isDeviceSymbols);
+ });
+}
+
+bool ClauseProcessor::processHasDeviceAddr(
+ llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
+ const {
+ return findRepeatableClause<ClauseTy::HasDeviceAddr>(
+ [&](const ClauseTy::HasDeviceAddr *devAddrClause,
+ const Fortran::parser::CharBlock &) {
+ addUseDeviceClause(converter, devAddrClause->v, operands, isDeviceTypes,
+ isDeviceLocs, isDeviceSymbols);
+ });
+}
+
template <typename... Ts>
void ClauseProcessor::processTODO(mlir::Location currentLocation,
llvm::omp::Directive directive) const {
@@ -2617,6 +2657,10 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Type> mapSymTypes;
llvm::SmallVector<mlir::Location> mapSymLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
+ llvm::SmallVector<mlir::Value> devicePtrOperands, deviceAddrOperands;
+ llvm::SmallVector<mlir::Type> useDeviceTypes;
+ llvm::SmallVector<mlir::Location> useDeviceLocs;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
ClauseProcessor cp(converter, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
@@ -2626,11 +2670,13 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
cp.processNowait(nowaitAttr);
cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols);
+ cp.processIsDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
+ useDeviceSymbols);
+ cp.processHasDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
+ useDeviceSymbols);
cp.processTODO<Fortran::parser::OmpClause::Private,
Fortran::parser::OmpClause::Depend,
Fortran::parser::OmpClause::Firstprivate,
- Fortran::parser::OmpClause::IsDevicePtr,
- Fortran::parser::OmpClause::HasDeviceAddr,
Fortran::parser::OmpClause::Reduction,
Fortran::parser::OmpClause::InReduction,
Fortran::parser::OmpClause::Allocate,
@@ -2705,7 +2751,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
- nowaitAttr, mapOperands);
+ nowaitAttr, devicePtrOperands, deviceAddrOperands, mapOperands);
genBodyOfTargetOp(converter, eval, targetOp, mapSymTypes, mapSymLocs,
mapSymbols, currentLocation);
@@ -3101,6 +3147,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
!std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
+ !std::get_if<Fortran::parser::OmpClause::IsDevicePtr>(&clause.u) &&
+ !std::get_if<Fortran::parser::OmpClause::HasDeviceAddr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90
index 2034ac84334e54..8f38261fc1aa62 100644
--- a/flang/test/Lower/OpenMP/FIR/target.f90
+++ b/flang/test/Lower/OpenMP/FIR/target.f90
@@ -411,4 +411,43 @@ subroutine omp_target_parallel_do
!CHECK: omp.terminator
!CHECK: }
!$omp end target parallel do
- end subroutine omp_target_parallel_do
+end subroutine omp_target_parallel_do
+
+!===============================================================================
+! Target `is_device_ptr` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
+subroutine omp_target_is_device_ptr
+ use iso_c_binding, only : c_ptr, c_loc
+ !CHECK: %[[DEV_PTR:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "a", uniq_name = "_QFomp_target_is_device_ptrEa"}
+ type(c_ptr) :: a
+ !CHECK %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "b", fir.target, uniq_name = "_QFomp_target_is_device_ptrEb"}
+ integer, target :: b
+ !CHECK: %[[MAP_0:.*]] = omp.map_info var_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>> {name = "a"}
+ !CHECK: %[[MAP_1:.*]] = omp.map_info var_ptr(%[[VAL_0:.*]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "b"}
+ !CHECK: omp.target is_device_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) map_entries(%[[MAP_0:.*]], %[[MAP_1:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<i32>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+ !$omp target map(tofrom: a,b) is_device_ptr(a)
+ !CHECK: {{.*}} = fir.coordinate_of %[[DEV_PTR:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+ a = c_loc(b)
+ !CHECK: omp.terminator
+ !$omp end target
+ !CHECK: }
+end subroutine omp_target_is_device_ptr
+
+ !===============================================================================
+ ! Target `has_device_addr` clause
+ !===============================================================================
+
+ !CHECK-LABEL: func.func @_QPomp_target_has_device_addr() {
+ subroutine omp_target_has_device_addr
+ integer, pointer :: a
+ !CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "a", uniq_name = "_QFomp_target_has_device_addrEa"}
+ !CHECK: omp.target has_device_addr(%[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+ !$omp target has_device_addr(a)
+ !CHECK: {{.*}} = fir.load %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
+ a = 10
+ !CHECK: omp.terminator
+ !$omp end target
+ !CHECK: }
+end subroutine omp_target_has_device_addr
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8ff5380f71ad45..92ab32b0131fa6 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1389,10 +1389,19 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
The optional $thread_limit specifies the limit on the number of threads
- The optional $nowait elliminates the implicit barrier so the parent task can make progress
+ The optional $nowait eliminates the implicit barrier so the parent task can make progress
even if the target task is not yet completed.
- TODO: is_device_ptr, depend, defaultmap, in_reduction
+ The optional $is_device_ptr indicates list items are device pointers.
+
+ The optional $has_device_addr indicates that list items already have device
+ addresses, so may be directly accessed from target device. May include array
+ sections.
+
+ The optional $map_operands maps data from the task’s environment to the
+ device environment.
+
+ TODO: depend, defaultmap, in_reduction
}];
@@ -1400,8 +1409,9 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
Optional<AnyInteger>:$device,
Optional<AnyInteger>:$thread_limit,
UnitAttr:$nowait,
+ Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
+ Variadic<OpenMP_PointerLikeType>:$has_device_addr,
Variadic<AnyType>:$map_operands);
-
let regions = (region AnyRegion:$region);
let assemblyFormat = [{
@@ -1409,6 +1419,8 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
| `device` `(` $device `:` type($device) `)`
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
| `nowait` $nowait
+ | `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
+ | `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
) $region attr-dict
}];
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4d88d9ac86fe16..b153b1b8221d80 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -480,22 +480,22 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
}
// CHECK-LABEL: omp_target
-func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
+func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %device_ptr: memref<i32>, %device_addr: memref<?xi32>, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
// Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
// CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : ( i1, si32, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
// CHECK: %[[MAP_B:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
- // CHECK: omp.target map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
+ // CHECK: omp.target is_device_ptr(%[[VAL_4:.*]] : memref<i32>) has_device_addr(%[[VAL_5:.*]] : memref<?xi32>) map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
%mapv1 = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
%mapv2 = omp.map_info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
- omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) {
+ omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) is_device_ptr(%device_ptr : memref<i32>) has_device_addr(%device_addr : memref<?xi32>) {
^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>):
omp.terminator
}
>From f6c8748647fd4160f43dc7eeb995f59fc235b89e Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Mon, 11 Dec 2023 22:51:12 +0000
Subject: [PATCH 2/2] [Flang][MLIR][OpenMP] WIP: Privatisation for index
variables
---
flang/lib/Lower/OpenMP.cpp | 118 +++++++++++++++++-
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 2 +
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 73 ++++++++++-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 20 +++
.../OpenMPToLLVM/convert-to-llvmir.mlir | 2 +-
mlir/test/Dialect/OpenMP/ops.mlir | 10 +-
mlir/test/Target/LLVMIR/openmp-llvm.mlir | 6 +-
7 files changed, 217 insertions(+), 14 deletions(-)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index f0d6a3f382e5ff..428b4b171bbabe 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2115,6 +2115,7 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
mlir::Type tempTy = converter.genType(*sym);
+ llvm::outs() << "Temp type = " << tempTy << "\n";
mlir::Value temp = firOpBuilder.create<fir::AllocaOp>(
loc, tempTy, /*pinned=*/true, /*lengthParams=*/mlir::ValueRange{},
/*shapeParams*/ mlir::ValueRange{},
@@ -2128,6 +2129,103 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
return storeOp;
}
+/// Create the body (block) for an OpenMP Loop Operation.
+///
+/// \param [in] op - the operation the body belongs to.
+/// \param [inout] converter - converter to use for the clauses.
+/// \param [in] loc - location in source code.
+/// \param [in] eval - current PFT node/evaluation.
+/// \param [in] clauses - list of clauses to process.
+/// \param [in] args - block arguments (induction variable[s]) for the
+/// region.
+/// \param [in] outerCombined - is this an outer operation - prevents
+/// privatization.
+template <typename Op>
+static void createBodyOfLoopOp(
+ Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OmpClauseList *clauses = nullptr,
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
+ bool outerCombined = false, DataSharingProcessor *dsp = nullptr) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ // If an argument for the region is provided then create the block with that
+ // argument. Also update the symbol's address with the mlir argument value.
+ // e.g. For loops the argument is the induction variable. And all further
+ // uses of the induction variable should use this mlir value.
+ mlir::Operation *storeOp = nullptr;
+ assert(args.size() > 0);
+ std::size_t loopVarTypeSize = 0;
+ for (const Fortran::semantics::Symbol *arg : args)
+ loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
+ mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+ llvm::SmallVector<mlir::Type> tiv;
+ llvm::SmallVector<mlir::Location> locs;
+ for (int i = 0; i < (int)args.size(); i++) {
+ tiv.push_back(loopVarType);
+ locs.push_back(loc);
+ }
+ int offset = 0;
+ // The argument is not currently in memory, so make a temporary for the
+ // argument, and store it there, then bind that location to the argument.
+ for (const Fortran::semantics::Symbol *arg : args) {
+ mlir::Type symType = converter.genType(*arg);
+ mlir::Type symRefType = firOpBuilder.getRefType(symType);
+ tiv.push_back(symRefType);
+ locs.push_back(loc);
+ offset++;
+ }
+ firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
+
+ int argIndex = 0;
+ for (const Fortran::semantics::Symbol *arg : args) {
+ mlir::Value addrVal =
+ fir::getBase(op.getRegion().front().getArgument(argIndex+offset));
+ converter.bindSymbol(*arg, addrVal);
+ mlir::Type symType = converter.genType(*arg);
+ mlir::Value indexVal =
+ fir::getBase(op.getRegion().front().getArgument(argIndex));
+ mlir::Value cvtVal = firOpBuilder.createConvert(loc, symType, indexVal);
+ addrVal = converter.getSymbolAddress(*arg);
+ storeOp = firOpBuilder.create<fir::StoreOp>(loc, cvtVal, addrVal);
+ argIndex++;
+ }
+ // Set the insert for the terminator operation to go at the end of the
+ // block - this is either empty or the block with the stores above,
+ // the end of the block works for both.
+ mlir::Block &block = op.getRegion().back();
+ firOpBuilder.setInsertionPointToEnd(&block);
+
+ // If it is an unstructured region and is not the outer region of a combined
+ // construct, create empty blocks for all evaluations.
+ if (eval.lowerAsUnstructured() && !outerCombined)
+ Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
+ mlir::omp::YieldOp>(
+ firOpBuilder, eval.getNestedEvaluations());
+
+ // Insert the terminator.
+ Fortran::lower::genOpenMPTerminator(firOpBuilder, op.getOperation(), loc);
+ // Reset the insert point to before the terminator.
+ resetBeforeTerminator(firOpBuilder, storeOp, block);
+
+ // Handle privatization. Do not privatize if this is the outer operation.
+ if (clauses && !outerCombined) {
+ constexpr bool isLoop = std::is_same_v<Op, mlir::omp::WsLoopOp> ||
+ std::is_same_v<Op, mlir::omp::SimdLoopOp>;
+ if (!dsp) {
+ DataSharingProcessor proc(converter, *clauses, eval);
+ proc.processStep1();
+ proc.processStep2(op, isLoop);
+ } else {
+ if (isLoop && args.size() > 0)
+ dsp->setLoopIV(converter.getSymbolAddress(*args[0]));
+ dsp->processStep2(op, isLoop);
+ }
+
+ if (storeOp)
+ firOpBuilder.setInsertionPointAfter(storeOp);
+ }
+}
+
/// Create the body (block) for an OpenMP Operation.
///
/// \param [in] op - the operation the body belongs to.
@@ -2960,7 +3058,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars,
- linearStepVars, reductionVars;
+ linearStepVars, privateVars, reductionVars;
mlir::Value scheduleChunkClauseOperand;
mlir::IntegerAttr orderedClauseOperand;
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
@@ -3069,9 +3167,23 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
return;
}
+ // Collect the loops to collapse.
+ Fortran::lower::pft::Evaluation *doConstructEval =
+ &eval.getFirstNestedEvaluation();
+ Fortran::lower::pft::Evaluation *doLoop =
+ &doConstructEval->getFirstNestedEvaluation();
+ auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
+ assert(doStmt && "Expected do loop to be in the nested evaluation");
+ const auto &loopControl =
+ std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
+ const Fortran::parser::LoopControl::Bounds *bounds =
+ std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+ assert(bounds && "Expected bounds for worksharing do loop");
+ privateVars.push_back(converter.getSymbolAddress(*bounds->name.thing.symbol));
+
auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
- reductionVars,
+ privateVars, reductionVars,
reductionDeclSymbols.empty()
? nullptr
: mlir::ArrayAttr::get(firOpBuilder.getContext(),
@@ -3107,7 +3219,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
wsLoopOp.setNowaitAttr(nowaitClauseOperand);
}
- createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, currentLocation,
+ createBodyOfLoopOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, currentLocation,
eval, &loopOpClauseList, iv,
/*outer=*/false, &dsp);
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 92ab32b0131fa6..b19efcaa867676 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -479,6 +479,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
Variadic<IntLikeType>:$step,
Variadic<AnyType>:$linear_vars,
Variadic<I32>:$linear_step_vars,
+ Variadic<OpenMP_PointerLikeType>:$privates,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
OptionalAttr<ScheduleKindAttr>:$schedule_val,
@@ -517,6 +518,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
|`nowait` $nowait
|`ordered` `(` $ordered_val `)`
|`order` `(` custom<ClauseAttr>($order_val) `)`
+ |`private` `(` custom<PrivateEntries>($privates, type($privates)) `)`
|`reduction` `(`
custom<ReductionVarList>(
$reduction_vars, type($reduction_vars), $reductions
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 20df0099cbd24d..3da37231e3c52c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -178,6 +178,67 @@ void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
p << stringifyEnum(attr.getValue());
}
+static ParseResult
+parsePrivateEntries(OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
+ SmallVectorImpl<Type> &privateOperandTypes) {
+ OpAsmParser::UnresolvedOperand arg;
+ OpAsmParser::UnresolvedOperand blockArg;
+ Type argType;
+ auto parseEntries = [&]() -> ParseResult {
+ if (parser.parseOperand(arg) || parser.parseArrow() ||
+ parser.parseOperand(blockArg))
+ return failure();
+ privateOperands.push_back(arg);
+ return success();
+ };
+
+ auto parseTypes = [&]() -> ParseResult {
+ if (parser.parseType(argType))
+ return failure();
+ privateOperandTypes.push_back(argType);
+ return success();
+ };
+
+ if (parser.parseCommaSeparatedList(parseEntries))
+ return failure();
+
+ if (parser.parseColon())
+ return failure();
+
+ if (parser.parseCommaSeparatedList(parseTypes))
+ return failure();
+
+ return success();
+}
+
+static void printPrivateEntries(OpAsmPrinter &p, Operation *op,
+ OperandRange privateOperands,
+ TypeRange privateOperandTypes) {
+ auto &region = op->getRegion(0);
+
+ unsigned argIndex = 0;
+ unsigned offset = 0;
+ if (auto wsLoop = dyn_cast<WsLoopOp>(op))
+ offset = wsLoop.getNumLoops();
+ for (const auto &privOperand : privateOperands) {
+ const auto &blockArg = region.front().getArgument(argIndex+offset);
+ p << privOperand << " -> " << blockArg;
+ argIndex++;
+ if (argIndex < privateOperands.size())
+ p << ", ";
+ }
+ p << " : ";
+
+ argIndex = 0;
+ for (const auto &privOperandType : privateOperandTypes) {
+ p << privOperandType;
+ argIndex++;
+ if (argIndex < privateOperands.size())
+ p << ", ";
+ }
+}
+
//===----------------------------------------------------------------------===//
// Parser and printer for Linear Clause
//===----------------------------------------------------------------------===//
@@ -1086,7 +1147,14 @@ void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange steps, TypeRange loopVarTypes,
UnitAttr inclusive) {
auto args = region.front().getArguments();
- p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
+ p << " (";
+ unsigned numLoops = steps.size();
+ for (unsigned i=0; i<numLoops; i++) {
+ if (i != 0)
+ p << ", ";
+ p << args[i];
+ }
+ p << ") : " << args[0].getType() << " = (" << lowerBound
<< ") to (" << upperBound << ") ";
if (inclusive)
p << "inclusive ";
@@ -1269,7 +1337,8 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &state,
ValueRange step, ArrayRef<NamedAttribute> attributes) {
build(builder, state, lowerBound, upperBound, step,
/*linear_vars=*/ValueRange(),
- /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
+ /*linear_step_vars=*/ValueRange(), /*private_vars=*/ValueRange(),
+ /*reduction_vars=*/ValueRange(),
/*reductions=*/nullptr, /*schedule_val=*/nullptr,
/*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
/*simd_modifier=*/false, /*nowait=*/false, /*ordered_val=*/nullptr,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4f6200d29a70a6..26ec12e427ff35 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -833,6 +833,24 @@ static void collectReductionInfo(
}
}
+/// Allocate space for the privatized variables of a workshare loop and map each private block argument to its allocation.
+void
+allocPrivatizationVars(omp::WsLoopOp loop, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ llvm::OpenMPIRBuilder::InsertPointTy &allocaIP) {
+ unsigned offset = loop.getNumLoops();
+ unsigned numArgs = loop.getRegion().front().getNumArguments();
+ llvm::IRBuilderBase::InsertPointGuard guard(builder);
+ builder.restoreIP(allocaIP);
+ for (unsigned i = offset; i < numArgs; ++i) {
+ if (auto op = loop.getPrivates()[i-offset].getDefiningOp<LLVM::AllocaOp>()) {
+ llvm::Value *var = builder.CreateAlloca(moduleTranslation.convertType(op.getResultPtrElementType()));
+ // moduleTranslation.convertType(loop.getPrivates()[i-offset].getType()));
+ moduleTranslation.mapValue(loop.getRegion().front().getArgument(i), var);
+ }
+ }
+}
+
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -861,6 +879,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
+ allocPrivatizationVars(loop, builder, moduleTranslation, allocaIP);
+
SmallVector<llvm::Value *> privateReductionVariables;
DenseMap<Value, llvm::Value *> reductionVariableMap;
allocReductionVars(loop, builder, moduleTranslation, allocaIP, reductionDecls,
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 3fbeaebb592a4d..86576831f67dff 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -79,7 +79,7 @@ func.func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4:
// CHECK: "test.payload"(%[[CAST_ARG6]], %[[CAST_ARG7]]) : (index, index) -> ()
"test.payload"(%arg6, %arg7) : (index, index) -> ()
omp.yield
- }) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 0>} : (index, index, index, index, index, index) -> ()
+ }) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 0, 0>} : (index, index, index, index, index, index) -> ()
omp.terminator
}
return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index b153b1b8221d80..134cd4227af485 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -141,7 +141,7 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
"omp.wsloop" (%lb, %ub, %step) ({
^bb0(%iv: index):
omp.yield
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>, ordered_val = 1} :
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0,0>, ordered_val = 1} :
(index, index, index) -> ()
// CHECK: omp.wsloop linear(%{{.*}} = %{{.*}} : memref<i32>) schedule(static)
@@ -149,7 +149,7 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
"omp.wsloop" (%lb, %ub, %step, %data_var, %linear_var) ({
^bb0(%iv: index):
omp.yield
- }) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,0>, schedule_val = #omp<schedulekind static>} :
+ }) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,0,0>, schedule_val = #omp<schedulekind static>} :
(index, index, index, memref<i32>, i32) -> ()
// CHECK: omp.wsloop linear(%{{.*}} = %{{.*}} : memref<i32>, %{{.*}} = %{{.*}} : memref<i32>) schedule(static)
@@ -157,7 +157,7 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
"omp.wsloop" (%lb, %ub, %step, %data_var, %data_var, %linear_var, %linear_var) ({
^bb0(%iv: index):
omp.yield
- }) {operandSegmentSizes = array<i32: 1,1,1,2,2,0,0>, schedule_val = #omp<schedulekind static>} :
+ }) {operandSegmentSizes = array<i32: 1,1,1,2,2,0,0,0>, schedule_val = #omp<schedulekind static>} :
(index, index, index, memref<i32>, memref<i32>, i32, i32) -> ()
// CHECK: omp.wsloop linear(%{{.*}} = %{{.*}} : memref<i32>) schedule(dynamic = %{{.*}}) ordered(2)
@@ -165,7 +165,7 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
"omp.wsloop" (%lb, %ub, %step, %data_var, %linear_var, %chunk_var) ({
^bb0(%iv: index):
omp.yield
- }) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,1>, schedule_val = #omp<schedulekind dynamic>, ordered_val = 2} :
+ }) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,0,1>, schedule_val = #omp<schedulekind dynamic>, ordered_val = 2} :
(index, index, index, memref<i32>, i32, i32) -> ()
// CHECK: omp.wsloop schedule(auto) nowait
@@ -173,7 +173,7 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
"omp.wsloop" (%lb, %ub, %step) ({
^bb0(%iv: index):
omp.yield
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>, nowait, schedule_val = #omp<schedulekind auto>} :
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0,0>, nowait, schedule_val = #omp<schedulekind auto>} :
(index, index, index) -> ()
return
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 1c02c0265462c2..e0caaf06c03a10 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -310,7 +310,7 @@ llvm.func @wsloop_simple(%arg0: !llvm.ptr) {
llvm.store %3, %4 : f32, !llvm.ptr
omp.yield
// CHECK: call void @__kmpc_for_static_fini(ptr @[[$loc_struct]],
- }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+ }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
omp.terminator
}
llvm.return
@@ -330,7 +330,7 @@ llvm.func @wsloop_inclusive_1(%arg0: !llvm.ptr) {
%4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
llvm.store %3, %4 : f32, !llvm.ptr
omp.yield
- }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+ }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
llvm.return
}
@@ -348,7 +348,7 @@ llvm.func @wsloop_inclusive_2(%arg0: !llvm.ptr) {
%4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
llvm.store %3, %4 : f32, !llvm.ptr
omp.yield
- }) {inclusive, operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+ }) {inclusive, operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
llvm.return
}
More information about the flang-commits
mailing list