[flang-commits] [flang] [mlir] [Flang][OpenMP][Lower] Use clause operand structures (PR #86802)
Sergio Afonso via flang-commits
flang-commits at lists.llvm.org
Fri Apr 12 03:03:07 PDT 2024
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/86802
>From c9ee93f1c28e9d4518ec842f66f497bf0911c9d5 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 26 Mar 2024 16:31:34 +0000
Subject: [PATCH 1/2] [MLIR][OpenMP] Group clause operands into structures
This patch introduces a set of composable structures grouping the MLIR operands
associated with each OpenMP clause. This makes it easier to keep the MLIR
representation for the same clause consistent throughout all operations that
accept it.
The relevant clause operand structures are grouped into per-operation
structures using a mixin pattern and used to define new operation constructors.
These constructors can be used to avoid having to get the order of a possibly
large list of operands right.
Missing clauses are documented as TODOs, as are operands which are part of the
relevant operation's operand structure but cannot be attached to the associated
operation yet because its MLIR definition lacks the corresponding arguments.
A follow-up patch will update Flang lowering to make use of these structures,
simplifying the passing of information from clause processing to operation-
generating functions and also simplifying the creation of operations through
the use of the new operation constructors.
---
.../Dialect/OpenMP/OpenMPClauseOperands.h | 300 ++++++++++++++++++
.../mlir/Dialect/OpenMP/OpenMPDialect.h | 7 +-
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 72 ++++-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 226 ++++++++++++-
4 files changed, 595 insertions(+), 10 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
new file mode 100644
index 00000000000000..6454076f7593b3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -0,0 +1,300 @@
+//===-- OpenMPClauseOperands.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the structures defining MLIR operands associated with each
+// OpenMP clause, and structures grouping the appropriate operands for each
+// construct.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_
+#define MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc"
+
+namespace mlir {
+namespace omp {
+
+//===----------------------------------------------------------------------===//
+// Mixin structures defining MLIR operands associated with each OpenMP clause.
+//
+// Each `<Name>ClauseOps` structure holds the operands of a single clause and
+// is combined with others through `detail::Clauses` further below.
+//===----------------------------------------------------------------------===//
+
+// Operands of the `aligned` clause.
+struct AlignedClauseOps {
+ llvm::SmallVector<Value> alignedVars;
+ // NOTE(review): presumably one alignment per entry of `alignedVars`; confirm.
+ llvm::SmallVector<Attribute> alignmentAttrs;
+};
+
+// Operands of the `allocate` clause. Both lists are expected to have the same
+// length; e.g. the `omp.sections` and `omp.single` verifiers reject a mismatch.
+struct AllocateClauseOps {
+ llvm::SmallVector<Value> allocatorVars, allocateVars;
+};
+
+// Lower bounds, upper bounds and steps of the associated loop nest, as passed
+// to the loop ops (`omp.simdloop`, `omp.taskloop`, `omp.wsloop`).
+struct CollapseClauseOps {
+ llvm::SmallVector<Value> loopLBVar, loopUBVar, loopStepVar;
+};
+
+// Operands of the `copyprivate` clause: the variables and, presumably, the
+// symbols of the functions used to copy each of them (TODO confirm).
+struct CopyprivateClauseOps {
+ llvm::SmallVector<Value> copyprivateVars;
+ llvm::SmallVector<Attribute> copyprivateFuncs;
+};
+
+// Operands of the `depend` clause: dependence types and dependent variables.
+// The op builders wrap `dependTypeAttrs` into an `ArrayAttr`.
+struct DependClauseOps {
+ llvm::SmallVector<Attribute> dependTypeAttrs;
+ llvm::SmallVector<Value> dependVars;
+};
+
+// Operand of the `device` clause.
+struct DeviceClauseOps {
+ Value deviceVar;
+};
+
+struct DeviceTypeClauseOps {
+ // Value of the `device_type` clause; defaults to `any`.
+ DeclareTargetDeviceType deviceType = DeclareTargetDeviceType::any;
+};
+
+// Operands of the `dist_schedule` clause. `DistributeOp::verify` rejects a
+// chunk size given without the static schedule attribute being set.
+struct DistScheduleClauseOps {
+ UnitAttr distScheduleStaticAttr;
+ Value distScheduleChunkSizeVar;
+};
+
+// Operands of a doacross dependence, as taken by `omp.ordered`: the dependence
+// vector, the dependence type and the number of loops.
+struct DoacrossClauseOps {
+ llvm::SmallVector<Value> doacrossVectorVars;
+ ClauseDependAttr doacrossDependTypeAttr;
+ IntegerAttr doacrossNumLoopsAttr;
+};
+
+// Operand of the `final` clause.
+struct FinalClauseOps {
+ Value finalVar;
+};
+
+// Operand of the `grainsize` clause.
+struct GrainsizeClauseOps {
+ Value grainsizeVar;
+};
+
+// Operand of the `hint` clause.
+struct HintClauseOps {
+ IntegerAttr hintAttr;
+};
+
+// Operand of the `if` clause.
+struct IfClauseOps {
+ Value ifVar;
+};
+
+// Operands of the `in_reduction` clause: the reduction variables and the
+// symbols of their reduction declarations.
+struct InReductionClauseOps {
+ llvm::SmallVector<Value> inReductionVars;
+ llvm::SmallVector<Attribute> inReductionDeclSymbols;
+};
+
+// Operands of the `linear` clause: the linear variables and their steps.
+struct LinearClauseOps {
+ llvm::SmallVector<Value> linearVars, linearStepVars;
+};
+
+// Operands of loop constructs (simd, taskloop, wsloop) that are not tied to a
+// specific clause.
+struct LoopRelatedOps {
+ // NOTE(review): presumably marks the loop upper bounds as inclusive; confirm.
+ UnitAttr loopInclusiveAttr;
+};
+
+// Operands of the `map` clause.
+struct MapClauseOps {
+ llvm::SmallVector<Value> mapVars;
+};
+
+// Operand of the `mergeable` clause.
+struct MergeableClauseOps {
+ UnitAttr mergeableAttr;
+};
+
+// Name of a `critical` construct (the symbol of `omp.critical.declare`).
+struct NameClauseOps {
+ StringAttr nameAttr;
+};
+
+// Operand of the `nogroup` clause.
+struct NogroupClauseOps {
+ UnitAttr nogroupAttr;
+};
+
+// Operands of the `nontemporal` clause.
+struct NontemporalClauseOps {
+ llvm::SmallVector<Value> nontemporalVars;
+};
+
+// Operand of the `nowait` clause.
+struct NowaitClauseOps {
+ UnitAttr nowaitAttr;
+};
+
+// Operand of the `num_tasks` clause.
+struct NumTasksClauseOps {
+ Value numTasksVar;
+};
+
+// Operands of the `num_teams` clause: its lower and upper bounds.
+struct NumTeamsClauseOps {
+ Value numTeamsLowerVar, numTeamsUpperVar;
+};
+
+// Operand of the `num_threads` clause.
+struct NumThreadsClauseOps {
+ Value numThreadsVar;
+};
+
+// Operand of the `order` clause.
+struct OrderClauseOps {
+ ClauseOrderKindAttr orderAttr;
+};
+
+// Operand of the `ordered` clause.
+struct OrderedClauseOps {
+ IntegerAttr orderedAttr;
+};
+
+// Parallelization level of an `ordered` region; only `simd` is represented.
+struct ParallelizationLevelClauseOps {
+ UnitAttr parLevelSimdAttr;
+};
+
+// Operand of the `priority` clause.
+struct PriorityClauseOps {
+ Value priorityVar;
+};
+
+// Operands of the `private` clause.
+struct PrivateClauseOps {
+ // SSA values that correspond to "original" values being privatized.
+ // They refer to the SSA value outside the OpenMP region from which a clone is
+ // created inside the region.
+ llvm::SmallVector<Value> privateVars;
+ // The list of symbols referring to delayed privatizer ops (i.e. `omp.private`
+ // ops).
+ llvm::SmallVector<Attribute> privatizers;
+};
+
+// Operand of the `proc_bind` clause.
+struct ProcBindClauseOps {
+ ClauseProcBindKindAttr procBindKindAttr;
+};
+
+// Operands of the `reduction` clause: the reduction variables, the symbols of
+// their reduction declarations and a flag that presumably selects by-reference
+// reductions (TODO confirm).
+struct ReductionClauseOps {
+ llvm::SmallVector<Value> reductionVars;
+ llvm::SmallVector<Attribute> reductionDeclSymbols;
+ UnitAttr reductionByRefAttr;
+};
+
+// Operand of the `safelen` clause.
+struct SafelenClauseOps {
+ IntegerAttr safelenAttr;
+};
+
+// Operands of the `schedule` clause: schedule kind, modifier, chunk size and
+// a `simd` flag.
+struct ScheduleClauseOps {
+ ClauseScheduleKindAttr scheduleValAttr;
+ ScheduleModifierAttr scheduleModAttr;
+ Value scheduleChunkVar;
+ UnitAttr scheduleSimdAttr;
+};
+
+// Operand of the `simdlen` clause.
+struct SimdlenClauseOps {
+ IntegerAttr simdlenAttr;
+};
+
+// Operands of the `task_reduction` clause: the reduction variables and the
+// symbols of their reduction declarations.
+struct TaskReductionClauseOps {
+ llvm::SmallVector<Value> taskReductionVars;
+ llvm::SmallVector<Attribute> taskReductionDeclSymbols;
+};
+
+// Operand of the `thread_limit` clause.
+struct ThreadLimitClauseOps {
+ Value threadLimitVar;
+};
+
+// Operand of the `untied` clause.
+struct UntiedClauseOps {
+ UnitAttr untiedAttr;
+};
+
+// Operands of the `use_device_ptr` and `use_device_addr` clauses.
+struct UseDeviceClauseOps {
+ llvm::SmallVector<Value> useDevicePtrVars, useDeviceAddrVars;
+};
+
+//===----------------------------------------------------------------------===//
+// Structures defining clause operands associated with each OpenMP leaf
+// construct.
+//
+// These mirror the arguments expected by the corresponding OpenMP MLIR ops.
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+// Combines the given clause mixins through multiple inheritance, so that the
+// resulting structure exposes all of their members directly.
+template <typename... Mixins>
+struct Clauses : public Mixins... {};
+} // namespace detail
+
+// Clause operands of `omp.critical.declare`.
+using CriticalClauseOps = detail::Clauses<HintClauseOps, NameClauseOps>;
+
+// TODO `indirect` clause.
+using DeclareTargetClauseOps = detail::Clauses<DeviceTypeClauseOps>;
+
+using DistributeClauseOps =
+ detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
+ PrivateClauseOps>;
+
+// TODO `filter` clause.
+using MaskedClauseOps = detail::Clauses<>;
+
+// Clause operands of the standalone `omp.ordered` op (doacross form).
+using OrderedOpClauseOps = detail::Clauses<DoacrossClauseOps>;
+
+// Clause operands of the `omp.ordered.region` op.
+using OrderedRegionClauseOps = detail::Clauses<ParallelizationLevelClauseOps>;
+
+using ParallelClauseOps =
+ detail::Clauses<AllocateClauseOps, IfClauseOps, NumThreadsClauseOps,
+ PrivateClauseOps, ProcBindClauseOps, ReductionClauseOps>;
+
+using SectionsClauseOps = detail::Clauses<AllocateClauseOps, NowaitClauseOps,
+ PrivateClauseOps, ReductionClauseOps>;
+
+// TODO `linear` clause.
+using SimdLoopClauseOps =
+ detail::Clauses<AlignedClauseOps, CollapseClauseOps, IfClauseOps,
+ LoopRelatedOps, NontemporalClauseOps, OrderClauseOps,
+ PrivateClauseOps, ReductionClauseOps, SafelenClauseOps,
+ SimdlenClauseOps>;
+
+using SingleClauseOps = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps,
+ NowaitClauseOps, PrivateClauseOps>;
+
+// TODO `defaultmap`, `has_device_addr`, `is_device_ptr`, `uses_allocators`
+// clauses.
+using TargetClauseOps =
+ detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps,
+ IfClauseOps, InReductionClauseOps, MapClauseOps,
+ NowaitClauseOps, PrivateClauseOps, ReductionClauseOps,
+ ThreadLimitClauseOps>;
+
+using TargetDataClauseOps = detail::Clauses<DeviceClauseOps, IfClauseOps,
+ MapClauseOps, UseDeviceClauseOps>;
+
+// Shared by `omp.target_enter_data`, `omp.target_exit_data` and
+// `omp.target_update`.
+using TargetEnterExitUpdateDataClauseOps =
+ detail::Clauses<DependClauseOps, DeviceClauseOps, IfClauseOps, MapClauseOps,
+ NowaitClauseOps>;
+
+// TODO `affinity`, `detach` clauses.
+using TaskClauseOps =
+ detail::Clauses<AllocateClauseOps, DependClauseOps, FinalClauseOps,
+ IfClauseOps, InReductionClauseOps, MergeableClauseOps,
+ PriorityClauseOps, PrivateClauseOps, UntiedClauseOps>;
+
+using TaskgroupClauseOps =
+ detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>;
+
+using TaskloopClauseOps =
+ detail::Clauses<AllocateClauseOps, CollapseClauseOps, FinalClauseOps,
+ GrainsizeClauseOps, IfClauseOps, InReductionClauseOps,
+ LoopRelatedOps, MergeableClauseOps, NogroupClauseOps,
+ NumTasksClauseOps, PriorityClauseOps, PrivateClauseOps,
+ ReductionClauseOps, UntiedClauseOps>;
+
+// NOTE(review): `omp.taskwait` cannot hold any of these operands yet; its
+// builder currently drops them.
+using TaskwaitClauseOps = detail::Clauses<DependClauseOps, NowaitClauseOps>;
+
+using TeamsClauseOps =
+ detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
+ PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
+
+using WsloopClauseOps =
+ detail::Clauses<AllocateClauseOps, CollapseClauseOps, LinearClauseOps,
+ LoopRelatedOps, NowaitClauseOps, OrderClauseOps,
+ OrderedClauseOps, PrivateClauseOps, ReductionClauseOps,
+ ScheduleClauseOps>;
+
+} // namespace omp
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
index 23509c5b607016..c656bdc870976f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -26,11 +26,10 @@
#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
-#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
-#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
-#define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc"
+#include "mlir/Dialect/OpenMP/OpenMPClauseOperands.h"
+
+#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index f33942b3c7c02d..2643348d668698 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -287,7 +287,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
let regions = (region AnyRegion:$region);
let builders = [
- OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+ OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+ OpBuilder<(ins CArg<"const ParallelClauseOps &">:$clauses)>
];
let extraClassDeclaration = [{
/// Returns the number of reduction variables.
@@ -362,6 +363,10 @@ def TeamsOp : OpenMP_Op<"teams", [
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TeamsClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(
`num_teams` `(` ( $num_teams_lower^ `:` type($num_teams_lower) )? `to`
@@ -451,6 +456,10 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
let regions = (region SizedRegion<1>:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const SectionsClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist( `reduction` `(`
custom<ReductionVarList>(
@@ -495,6 +504,10 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const SingleClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`allocate` `(`
custom<AllocateAndAllocator>(
@@ -601,6 +614,7 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
OpBuilder<(ins "ValueRange":$lowerBound, "ValueRange":$upperBound,
"ValueRange":$step,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+ OpBuilder<(ins CArg<"const WsloopClauseOps &">:$clauses)>
];
let regions = (region AnyRegion:$region);
@@ -698,6 +712,11 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
);
let regions = (region AnyRegion:$region);
+
+ let builders = [
+ OpBuilder<(ins CArg<"const SimdLoopClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`aligned` `(`
custom<AlignedClause>($aligned_vars, type($aligned_vars),
@@ -781,6 +800,10 @@ def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const DistributeClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`dist_schedule_static` $dist_schedule_static
|`chunk_size` `(` $chunk_size `:` type($chunk_size) `)`
@@ -883,6 +906,9 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars);
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TaskClauseOps &">:$clauses)>
+ ];
let assemblyFormat = [{
oilist(`if` `(` $if_expr `)`
|`final` `(` $final_expr `)`
@@ -1037,6 +1063,10 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TaskloopClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`if` `(` $if_expr `)`
|`final` `(` $final_expr `)`
@@ -1106,6 +1136,10 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TaskgroupClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`task_reduction` `(`
custom<ReductionVarList>(
@@ -1432,6 +1466,10 @@ def TargetDataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments,
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TargetDataClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
| `device` `(` $device `:` type($device) `)`
@@ -1486,6 +1524,10 @@ def TargetEnterDataOp: OpenMP_Op<"target_enter_data",
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);
+ let builders = [
+ OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
| `device` `(` $device `:` type($device) `)`
@@ -1540,6 +1582,10 @@ def TargetExitDataOp: OpenMP_Op<"target_exit_data",
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);
+ let builders = [
+ OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
| `device` `(` $device `:` type($device) `)`
@@ -1596,6 +1642,10 @@ def TargetUpdateOp: OpenMP_Op<"target_update", [AttrSizedOperandSegments,
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$map_operands);
+ let builders = [
+ OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
| `device` `(` $device `:` type($device) `)`
@@ -1649,6 +1699,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TargetClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist( `if` `(` $if_expr `)`
| `device` `(` $device `:` type($device) `)`
@@ -1693,6 +1747,10 @@ def CriticalDeclareOp : OpenMP_Op<"critical.declare", [Symbol]> {
let arguments = (ins SymbolNameAttr:$sym_name,
DefaultValuedAttr<I64Attr, "0">:$hint_val);
+ let builders = [
+ OpBuilder<(ins CArg<"const CriticalClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
$sym_name oilist(`hint` `(` custom<SynchronizationHint>($hint_val) `)`)
attr-dict
@@ -1773,6 +1831,10 @@ def OrderedOp : OpenMP_Op<"ordered"> {
ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$num_loops_val,
Variadic<AnyType>:$depend_vec_vars);
+ let builders = [
+ OpBuilder<(ins CArg<"const OrderedOpClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
( `depend_type` `` $depend_type_val^ )?
( `depend_vec` `(` $depend_vec_vars^ `:` type($depend_vec_vars) `)` )?
@@ -1797,6 +1859,10 @@ def OrderedRegionOp : OpenMP_Op<"ordered.region"> {
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const OrderedRegionClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{ ( `simd` $simd^ )? $region attr-dict}];
let hasVerifier = 1;
}
@@ -1812,6 +1878,10 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> {
of the current task.
}];
+ let builders = [
+ OpBuilder<(ins CArg<"const TaskwaitClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = "attr-dict";
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index bf5875071e0dc4..28869c1ddfb3fd 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -41,6 +41,11 @@
using namespace mlir;
using namespace mlir::omp;
+/// Wraps \p attrs into an `ArrayAttr` created in \p context. An empty list
+/// produces a null `ArrayAttr` instead of an empty array (presumably so the
+/// optional attribute is simply left unset on the op; confirm).
+static ArrayAttr makeArrayAttr(MLIRContext *context,
+ llvm::ArrayRef<Attribute> attrs) {
+ if (attrs.empty())
+ return ArrayAttr();
+ return ArrayAttr::get(context, attrs);
+}
+
namespace {
struct MemRefPointerLikeModel
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
@@ -1161,6 +1166,17 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
return success();
}
+//===----------------------------------------------------------------------===//
+// TargetDataOp
+//===----------------------------------------------------------------------===//
+
+/// Creates an `omp.target_data` operation from its grouped clause operands.
+void TargetDataOp::build(OpBuilder &builder, OperationState &state,
+ const TargetDataClauseOps &clauses) {
+ TargetDataOp::build(builder, state,
+ // `if` and `device` clauses.
+ clauses.ifVar, clauses.deviceVar,
+ // `use_device_ptr` and `use_device_addr` clauses.
+ clauses.useDevicePtrVars, clauses.useDeviceAddrVars,
+ // `map` clause.
+ clauses.mapVars);
+}
+
LogicalResult TargetDataOp::verify() {
if (getMapOperands().empty() && getUseDevicePtr().empty() &&
getUseDeviceAddr().empty()) {
@@ -1170,6 +1186,20 @@ LogicalResult TargetDataOp::verify() {
return verifyMapClause(*this, getMapOperands());
}
+//===----------------------------------------------------------------------===//
+// TargetEnterDataOp
+//===----------------------------------------------------------------------===//
+
+/// Creates an `omp.target_enter_data` operation from its grouped clause
+/// operands.
+void TargetEnterDataOp::build(
+ OpBuilder &builder, OperationState &state,
+ const TargetEnterExitUpdateDataClauseOps &clauses) {
+ ArrayAttr dependTypes =
+ makeArrayAttr(builder.getContext(), clauses.dependTypeAttrs);
+ TargetEnterDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+ dependTypes, clauses.dependVars, clauses.nowaitAttr,
+ clauses.mapVars);
+}
+
LogicalResult TargetEnterDataOp::verify() {
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1177,6 +1207,20 @@ LogicalResult TargetEnterDataOp::verify() {
: verifyMapClause(*this, getMapOperands());
}
+//===----------------------------------------------------------------------===//
+// TargetExitDataOp
+//===----------------------------------------------------------------------===//
+
+/// Creates an `omp.target_exit_data` operation from its grouped clause
+/// operands.
+void TargetExitDataOp::build(
+ OpBuilder &builder, OperationState &state,
+ const TargetEnterExitUpdateDataClauseOps &clauses) {
+ ArrayAttr dependTypes =
+ makeArrayAttr(builder.getContext(), clauses.dependTypeAttrs);
+ TargetExitDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+ dependTypes, clauses.dependVars, clauses.nowaitAttr,
+ clauses.mapVars);
+}
+
LogicalResult TargetExitDataOp::verify() {
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1184,6 +1228,19 @@ LogicalResult TargetExitDataOp::verify() {
: verifyMapClause(*this, getMapOperands());
}
+//===----------------------------------------------------------------------===//
+// TargetUpdateOp
+//===----------------------------------------------------------------------===//
+
+/// Creates an `omp.target_update` operation from its grouped clause operands.
+void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
+ const TargetEnterExitUpdateDataClauseOps &clauses) {
+ ArrayAttr dependTypes =
+ makeArrayAttr(builder.getContext(), clauses.dependTypeAttrs);
+ TargetUpdateOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+ dependTypes, clauses.dependVars, clauses.nowaitAttr,
+ clauses.mapVars);
+}
+
LogicalResult TargetUpdateOp::verify() {
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1191,6 +1248,22 @@ LogicalResult TargetUpdateOp::verify() {
: verifyMapClause(*this, getMapOperands());
}
+//===----------------------------------------------------------------------===//
+// TargetOp
+//===----------------------------------------------------------------------===//
+
+/// Creates an `omp.target` operation from its grouped clause operands.
+void TargetOp::build(OpBuilder &builder, OperationState &state,
+ const TargetClauseOps &clauses) {
+ // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
+ // inReductionDeclSymbols, privateVars, privatizers, reductionVars,
+ // reductionByRefAttr, reductionDeclSymbols.
+ ArrayAttr dependTypes =
+ makeArrayAttr(builder.getContext(), clauses.dependTypeAttrs);
+ TargetOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+ clauses.threadLimitVar, dependTypes, clauses.dependVars,
+ clauses.nowaitAttr, clauses.mapVars);
+}
+
LogicalResult TargetOp::verify() {
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1213,6 +1286,17 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
state.addAttributes(attributes);
}
+/// Creates an `omp.parallel` operation from its grouped clause operands.
+void ParallelOp::build(OpBuilder &builder, OperationState &state,
+ const ParallelClauseOps &clauses) {
+ MLIRContext *context = builder.getContext();
+ ArrayAttr reductionSyms = makeArrayAttr(context, clauses.reductionDeclSymbols);
+ ArrayAttr privatizerSyms = makeArrayAttr(context, clauses.privatizers);
+ ParallelOp::build(builder, state, clauses.ifVar, clauses.numThreadsVar,
+ clauses.allocateVars, clauses.allocatorVars,
+ clauses.reductionVars, reductionSyms,
+ clauses.procBindKindAttr, clauses.privateVars,
+ privatizerSyms, clauses.reductionByRefAttr);
+}
+
template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
auto privateVars = op.getPrivateVars();
@@ -1280,6 +1364,17 @@ static bool opInGlobalImplicitParallelRegion(Operation *op) {
return true;
}
+/// Creates an `omp.teams` operation from its grouped clause operands.
+void TeamsOp::build(OpBuilder &builder, OperationState &state,
+ const TeamsClauseOps &clauses) {
+ // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
+ ArrayAttr reductionSyms =
+ makeArrayAttr(builder.getContext(), clauses.reductionDeclSymbols);
+ TeamsOp::build(builder, state, clauses.numTeamsLowerVar,
+ clauses.numTeamsUpperVar, clauses.ifVar, clauses.threadLimitVar,
+ clauses.allocateVars, clauses.allocatorVars,
+ clauses.reductionVars, reductionSyms);
+}
+
LogicalResult TeamsOp::verify() {
// Check parent region
// TODO If nested inside of a target region, also check that it does not
@@ -1312,9 +1407,19 @@ LogicalResult TeamsOp::verify() {
}
//===----------------------------------------------------------------------===//
-// Verifier for SectionsOp
+// SectionsOp
//===----------------------------------------------------------------------===//
+/// Creates an `omp.sections` operation from its grouped clause operands.
+void SectionsOp::build(OpBuilder &builder, OperationState &state,
+ const SectionsClauseOps &clauses) {
+ // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
+ ArrayAttr reductionSyms =
+ makeArrayAttr(builder.getContext(), clauses.reductionDeclSymbols);
+ SectionsOp::build(builder, state, clauses.reductionVars, reductionSyms,
+ clauses.allocateVars, clauses.allocatorVars,
+ clauses.nowaitAttr);
+}
+
LogicalResult SectionsOp::verify() {
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
@@ -1334,6 +1439,20 @@ LogicalResult SectionsOp::verifyRegions() {
return success();
}
+//===----------------------------------------------------------------------===//
+// SingleOp
+//===----------------------------------------------------------------------===//
+
+/// Creates an `omp.single` operation from its grouped clause operands.
+void SingleOp::build(OpBuilder &builder, OperationState &state,
+ const SingleClauseOps &clauses) {
+ // TODO Store clauses in op: privateVars, privatizers.
+ ArrayAttr copyFuncs =
+ makeArrayAttr(builder.getContext(), clauses.copyprivateFuncs);
+ SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.copyprivateVars, copyFuncs, clauses.nowaitAttr);
+}
+
LogicalResult SingleOp::verify() {
// Check for allocate clause restrictions
if (getAllocateVars().size() != getAllocatorsVars().size())
@@ -1481,9 +1600,21 @@ void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion,
}
//===----------------------------------------------------------------------===//
-// Verifier for Simd construct [2.9.3.1]
+// Simd construct [2.9.3.1]
//===----------------------------------------------------------------------===//
+/// Creates an `omp.simdloop` operation from its grouped clause operands.
+void SimdLoopOp::build(OpBuilder &builder, OperationState &state,
+ const SimdLoopClauseOps &clauses) {
+ // TODO Store clauses in op: privateVars, reductionByRefAttr, reductionVars,
+ // privatizers, reductionDeclSymbols.
+ ArrayAttr alignments =
+ makeArrayAttr(builder.getContext(), clauses.alignmentAttrs);
+ SimdLoopOp::build(builder, state, clauses.loopLBVar, clauses.loopUBVar,
+ clauses.loopStepVar, clauses.alignedVars, alignments,
+ clauses.ifVar, clauses.nontemporalVars, clauses.orderAttr,
+ clauses.simdlenAttr, clauses.safelenAttr,
+ clauses.loopInclusiveAttr);
+}
+
LogicalResult SimdLoopOp::verify() {
if (this->getLowerBound().empty()) {
return emitOpError() << "empty lowerbound for simd loop operation";
@@ -1504,9 +1635,17 @@ LogicalResult SimdLoopOp::verify() {
}
//===----------------------------------------------------------------------===//
-// Verifier for Distribute construct [2.9.4.1]
+// Distribute construct [2.9.4.1]
//===----------------------------------------------------------------------===//
+/// Creates an `omp.distribute` operation from its grouped clause operands.
+void DistributeOp::build(OpBuilder &builder, OperationState &state,
+ const DistributeClauseOps &clauses) {
+ // TODO Store clauses in op: privateVars, privatizers.
+ DistributeOp::build(builder, state,
+ // `dist_schedule` clause.
+ clauses.distScheduleStaticAttr,
+ clauses.distScheduleChunkSizeVar,
+ // `allocate` clause.
+ clauses.allocateVars, clauses.allocatorVars,
+ // `order` clause.
+ clauses.orderAttr);
+}
+
LogicalResult DistributeOp::verify() {
if (this->getChunkSize() && !this->getDistScheduleStatic())
return emitOpError() << "chunk size set without "
@@ -1607,6 +1746,19 @@ LogicalResult ReductionOp::verify() {
//===----------------------------------------------------------------------===//
// TaskOp
//===----------------------------------------------------------------------===//
+
+/// Creates an `omp.task` operation from its grouped clause operands.
+void TaskOp::build(OpBuilder &builder, OperationState &state,
+ const TaskClauseOps &clauses) {
+ // TODO Store clauses in op: privateVars, privatizers.
+ MLIRContext *context = builder.getContext();
+ ArrayAttr inReductionSyms =
+ makeArrayAttr(context, clauses.inReductionDeclSymbols);
+ ArrayAttr dependTypes = makeArrayAttr(context, clauses.dependTypeAttrs);
+ TaskOp::build(builder, state, clauses.ifVar, clauses.finalVar,
+ clauses.untiedAttr, clauses.mergeableAttr,
+ clauses.inReductionVars, inReductionSyms, clauses.priorityVar,
+ dependTypes, clauses.dependVars, clauses.allocateVars,
+ clauses.allocatorVars);
+}
+
LogicalResult TaskOp::verify() {
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1619,6 +1771,15 @@ LogicalResult TaskOp::verify() {
//===----------------------------------------------------------------------===//
// TaskgroupOp
//===----------------------------------------------------------------------===//
+
+/// Creates an `omp.taskgroup` operation from its grouped clause operands.
+void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
+ const TaskgroupClauseOps &clauses) {
+ ArrayAttr taskReductionSyms =
+ makeArrayAttr(builder.getContext(), clauses.taskReductionDeclSymbols);
+ TaskgroupOp::build(builder, state, clauses.taskReductionVars,
+ taskReductionSyms, clauses.allocateVars,
+ clauses.allocatorVars);
+}
+
LogicalResult TaskgroupOp::verify() {
return verifyReductionVarList(*this, getTaskReductions(),
getTaskReductionVars());
@@ -1627,6 +1788,21 @@ LogicalResult TaskgroupOp::verify() {
//===----------------------------------------------------------------------===//
// TaskloopOp
//===----------------------------------------------------------------------===//
+
+/// Creates an `omp.taskloop` operation from its grouped clause operands.
+void TaskloopOp::build(OpBuilder &builder, OperationState &state,
+ const TaskloopClauseOps &clauses) {
+ // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
+ MLIRContext *context = builder.getContext();
+ ArrayAttr inReductionSyms =
+ makeArrayAttr(context, clauses.inReductionDeclSymbols);
+ ArrayAttr reductionSyms = makeArrayAttr(context, clauses.reductionDeclSymbols);
+ TaskloopOp::build(builder, state, clauses.loopLBVar, clauses.loopUBVar,
+ clauses.loopStepVar, clauses.loopInclusiveAttr,
+ clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
+ clauses.mergeableAttr, clauses.inReductionVars,
+ inReductionSyms, clauses.reductionVars, reductionSyms,
+ clauses.priorityVar, clauses.allocateVars,
+ clauses.allocatorVars, clauses.grainsizeVar,
+ clauses.numTasksVar, clauses.nogroupAttr);
+}
+
SmallVector<Value> TaskloopOp::getAllReductionVars() {
SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
getInReductionVars().end());
@@ -1680,14 +1856,33 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
state.addAttributes(attributes);
}
+/// Creates an `omp.wsloop` operation from its grouped clause operands.
+void WsloopOp::build(OpBuilder &builder, OperationState &state,
+ const WsloopClauseOps &clauses) {
+ // TODO Store clauses in op: allocateVars, allocatorVars, privateVars,
+ // privatizers.
+ ArrayAttr reductionSyms =
+ makeArrayAttr(builder.getContext(), clauses.reductionDeclSymbols);
+ WsloopOp::build(builder, state, clauses.loopLBVar, clauses.loopUBVar,
+ clauses.loopStepVar, clauses.linearVars,
+ clauses.linearStepVars, clauses.reductionVars, reductionSyms,
+ clauses.scheduleValAttr, clauses.scheduleChunkVar,
+ clauses.scheduleModAttr, clauses.scheduleSimdAttr,
+ clauses.nowaitAttr, clauses.reductionByRefAttr,
+ clauses.orderedAttr, clauses.orderAttr,
+ clauses.loopInclusiveAttr);
+}
+
LogicalResult WsloopOp::verify() {
return verifyReductionVarList(*this, getReductions(), getReductionVars());
}
//===----------------------------------------------------------------------===//
-// Verifier for critical construct (2.17.1)
+// Critical construct (2.17.1)
//===----------------------------------------------------------------------===//
+/// Creates an `omp.critical.declare` operation from its grouped clause
+/// operands: the construct's symbol name and its synchronization hint.
+void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
+ const CriticalClauseOps &clauses) {
+ StringAttr symName = clauses.nameAttr;
+ IntegerAttr hint = clauses.hintAttr;
+ CriticalDeclareOp::build(builder, state, symName, hint);
+}
+
LogicalResult CriticalDeclareOp::verify() {
return verifySynchronizationHint(*this, getHintVal());
}
@@ -1707,9 +1902,15 @@ LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
//===----------------------------------------------------------------------===//
-// Verifier for ordered construct
+// Ordered construct
//===----------------------------------------------------------------------===//
+/// Creates an `omp.ordered` operation from its grouped clause operands.
+void OrderedOp::build(OpBuilder &builder, OperationState &state,
+ const OrderedOpClauseOps &clauses) {
+ // The dependence type and loop count come before the dependence vector.
+ OrderedOp::build(builder, state,
+ clauses.doacrossDependTypeAttr,
+ clauses.doacrossNumLoopsAttr,
+ clauses.doacrossVectorVars);
+}
+
LogicalResult OrderedOp::verify() {
auto container = (*this)->getParentOfType<WsloopOp>();
if (!container || !container.getOrderedValAttr() ||
@@ -1726,6 +1927,11 @@ LogicalResult OrderedOp::verify() {
return success();
}
+void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
+ const OrderedRegionClauseOps &clauses) {
+ OrderedRegionOp::build(builder, state, clauses.parLevelSimdAttr);
+}
+
LogicalResult OrderedRegionOp::verify() {
// TODO: The code generation for ordered simd directive is not supported yet.
if (getSimd())
@@ -1742,6 +1948,16 @@ LogicalResult OrderedRegionOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// TaskwaitOp
+//===----------------------------------------------------------------------===//
+
+void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
+ const TaskwaitClauseOps &clauses) {
+ // TODO Store clauses in op: dependTypeAttrs, dependVars, nowaitAttr.
+ TaskwaitOp::build(builder, state);
+}
+
//===----------------------------------------------------------------------===//
// Verifier for AtomicReadOp
//===----------------------------------------------------------------------===//
>From 7af7e9d13fc2134e76bb532bfa4313aa3df17924 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 26 Mar 2024 16:46:56 +0000
Subject: [PATCH 2/2] [Flang][OpenMP][Lower] Use clause operand structures
This patch updates Flang lowering to use the new set of OpenMP clause operand
structures and their groupings into directive-specific sets of clause operands.
It simplifies the passing of information from the clause processor and the
creation of operations.
The `DataSharingProcessor` is slightly modified to not hold delayed
privatization state. Instead, optional arguments are added to `processStep1`
which are only passed when delayed privatization is used. This enables using
the clause operand structure for `private` and removes the need for the ad-hoc
`DelayedPrivatizationInfo` structure.
The processing of the `schedule` clause is updated to process the `chunk`
modifier rather than requiring two separate calls to the `ClauseProcessor`.
Lowering of a block-associated `ordered` construct is updated to emit a TODO
error if the `simd` clause is specified, since it is not currently supported by
the `ClauseProcessor` or later compilation stages.
Processing of the `schedule` clause is removed from `omp.simdloop`, since that
clause doesn't apply to `simd` constructs.
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 261 +++++----
flang/lib/Lower/OpenMP/ClauseProcessor.h | 105 ++--
.../lib/Lower/OpenMP/DataSharingProcessor.cpp | 38 +-
flang/lib/Lower/OpenMP/DataSharingProcessor.h | 45 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 517 +++++++-----------
5 files changed, 428 insertions(+), 538 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 0a57a1496289f4..ee1f6c2fbc7e89 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -162,14 +162,13 @@ getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
ifVal);
}
-static void
-addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
- const omp::ObjectList &objects,
- llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) {
+static void addUseDeviceClause(
+ Fortran::lower::AbstractConverter &converter,
+ const omp::ObjectList &objects,
+ llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) {
genObjectList(objects, converter, operands);
for (mlir::Value &operand : operands) {
checkMapType(operand.getLoc(), operand.getType());
@@ -177,25 +176,24 @@ addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
useDeviceLocs.push_back(operand.getLoc());
}
for (const omp::Object &object : objects)
- useDeviceSymbols.push_back(object.id());
+ useDeviceSyms.push_back(object.id());
}
static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
mlir::Location loc,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
std::size_t loopVarTypeSize) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
// The types of lower bound, upper bound, and step are converted into the
// type of the loop variable if necessary.
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
- lowerBound[it] =
- firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
- upperBound[it] =
- firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
- step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
+ for (unsigned it = 0; it < (unsigned)result.loopLBVar.size(); it++) {
+ result.loopLBVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopLBVar[it]);
+ result.loopUBVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopUBVar[it]);
+ result.loopStepVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopStepVar[it]);
}
}
@@ -205,9 +203,7 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
bool ClauseProcessor::processCollapse(
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const {
bool found = false;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -238,15 +234,15 @@ bool ClauseProcessor::processCollapse(
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
assert(bounds && "Expected bounds for worksharing do loop");
Fortran::lower::StatementContext stmtCtx;
- lowerBound.push_back(fir::getBase(converter.genExprValue(
+ result.loopLBVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
- upperBound.push_back(fir::getBase(converter.genExprValue(
+ result.loopUBVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
if (bounds->step) {
- step.push_back(fir::getBase(converter.genExprValue(
+ result.loopStepVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
} else { // If `step` is not present, assume it as `1`.
- step.push_back(firOpBuilder.createIntegerConstant(
+ result.loopStepVar.push_back(firOpBuilder.createIntegerConstant(
currentLocation, firOpBuilder.getIntegerType(32), 1));
}
iv.push_back(bounds->name.thing.symbol);
@@ -257,8 +253,7 @@ bool ClauseProcessor::processCollapse(
&*std::next(doConstructEval->getNestedEvaluations().begin());
} while (collapseValue > 0);
- convertLoopBounds(converter, currentLocation, lowerBound, upperBound, step,
- loopVarTypeSize);
+ convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
return found;
}
@@ -286,7 +281,7 @@ bool ClauseProcessor::processDefault() const {
}
bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+ mlir::omp::DeviceClauseOps &result) const {
const Fortran::parser::CharBlock *source = nullptr;
if (auto *clause = findUniqueClause<omp::clause::Device>(&source)) {
mlir::Location clauseLocation = converter.genLocation(*source);
@@ -298,25 +293,26 @@ bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
}
}
const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
- result = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
+ result.deviceVar =
+ fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processDeviceType(
- mlir::omp::DeclareTargetDeviceType &result) const {
+ mlir::omp::DeviceTypeClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::DeviceType>()) {
// Case: declare target ... device_type(any | host | nohost)
switch (clause->v) {
case omp::clause::DeviceType::DeviceTypeDescription::Nohost:
- result = mlir::omp::DeclareTargetDeviceType::nohost;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost;
break;
case omp::clause::DeviceType::DeviceTypeDescription::Host:
- result = mlir::omp::DeclareTargetDeviceType::host;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::host;
break;
case omp::clause::DeviceType::DeviceTypeDescription::Any:
- result = mlir::omp::DeclareTargetDeviceType::any;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::any;
break;
}
return true;
@@ -325,7 +321,7 @@ bool ClauseProcessor::processDeviceType(
}
bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+ mlir::omp::FinalClauseOps &result) const {
const Fortran::parser::CharBlock *source = nullptr;
if (auto *clause = findUniqueClause<omp::clause::Final>(&source)) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -333,100 +329,108 @@ bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
mlir::Value finalVal =
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
- result = firOpBuilder.createConvert(clauseLocation,
- firOpBuilder.getI1Type(), finalVal);
+ result.finalVar = firOpBuilder.createConvert(
+ clauseLocation, firOpBuilder.getI1Type(), finalVal);
return true;
}
return false;
}
-bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t hintValue = *Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(hintValue);
+ result.hintAttr = firOpBuilder.getI64IntegerAttr(hintValue);
return true;
}
return false;
}
-bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Mergeable>(result);
+bool ClauseProcessor::processMergeable(
+ mlir::omp::MergeableClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Mergeable>(result.mergeableAttr);
}
-bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Nowait>(result);
+bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Nowait>(result.nowaitAttr);
}
-bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+bool ClauseProcessor::processNumTeams(
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::NumTeamsClauseOps &result) const {
// TODO Get lower and upper bounds for num_teams when parser is updated to
// accept both.
if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
// auto lowerBound = std::get<std::optional<ExprTy>>(clause->t);
auto &upperBound = std::get<ExprTy>(clause->t);
- result = fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+ result.numTeamsUpperVar =
+ fir::getBase(converter.genExprValue(upperBound, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processNumThreads(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.numThreadsVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
-bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processOrdered(
+ mlir::omp::OrderedClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Ordered>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t orderedClauseValue = 0l;
if (clause->v.has_value())
orderedClauseValue = *Fortran::evaluate::ToInt64(*clause->v);
- result = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
+ result.orderedAttr = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
return true;
}
return false;
}
-bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+bool ClauseProcessor::processPriority(
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::PriorityClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.priorityVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processProcBind(
- mlir::omp::ClauseProcBindKindAttr &result) const {
+ mlir::omp::ProcBindClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- result = genProcBindKindAttr(firOpBuilder, *clause);
+ result.procBindKindAttr = genProcBindKindAttr(firOpBuilder, *clause);
return true;
}
return false;
}
-bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processSafelen(
+ mlir::omp::SafelenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
const std::optional<std::int64_t> safelenVal =
Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(*safelenVal);
+ result.safelenAttr = firOpBuilder.getI64IntegerAttr(*safelenVal);
return true;
}
return false;
}
bool ClauseProcessor::processSchedule(
- mlir::omp::ClauseScheduleKindAttr &valAttr,
- mlir::omp::ScheduleModifierAttr &modifierAttr,
- mlir::UnitAttr &simdModifierAttr) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ScheduleClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::MLIRContext *context = firOpBuilder.getContext();
@@ -451,53 +455,51 @@ bool ClauseProcessor::processSchedule(
break;
}
- mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
+ result.scheduleValAttr =
+ mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
+ mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
if (scheduleModifier != mlir::omp::ScheduleModifier::none)
- modifierAttr =
+ result.scheduleModAttr =
mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
- simdModifierAttr = firOpBuilder.getUnitAttr();
+ result.scheduleSimdAttr = firOpBuilder.getUnitAttr();
- valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processScheduleChunk(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
- if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
- result = fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
+ result.scheduleChunkVar =
+ fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
+
return true;
}
return false;
}
-bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processSimdlen(
+ mlir::omp::SimdlenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
const std::optional<std::int64_t> simdlenVal =
Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(*simdlenVal);
+ result.simdlenAttr = firOpBuilder.getI64IntegerAttr(*simdlenVal);
return true;
}
return false;
}
bool ClauseProcessor::processThreadLimit(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ThreadLimitClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.threadLimitVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
-bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Untied>(result);
+bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Untied>(result.untiedAttr);
}
//===----------------------------------------------------------------------===//
@@ -505,13 +507,12 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
//===----------------------------------------------------------------------===//
bool ClauseProcessor::processAllocate(
- llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
- llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
+ mlir::omp::AllocateClauseOps &result) const {
return findRepeatableClause<omp::clause::Allocate>(
[&](const omp::clause::Allocate &clause,
const Fortran::parser::CharBlock &) {
- genAllocateClause(converter, clause, allocatorOperands,
- allocateOperands);
+ genAllocateClause(converter, clause, result.allocatorVars,
+ result.allocateVars);
});
}
@@ -660,10 +661,9 @@ createCopyFunc(mlir::Location loc, Fortran::lower::AbstractConverter &converter,
return funcOp;
}
-bool ClauseProcessor::processCopyPrivate(
+bool ClauseProcessor::processCopyprivate(
mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
- llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const {
+ mlir::omp::CopyprivateClauseOps &result) const {
auto addCopyPrivateVar = [&](Fortran::semantics::Symbol *sym) {
mlir::Value symVal = converter.getSymbolAddress(*sym);
auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>();
@@ -690,10 +690,10 @@ bool ClauseProcessor::processCopyPrivate(
cpVar = alloca;
}
- copyPrivateVars.push_back(cpVar);
+ result.copyprivateVars.push_back(cpVar);
mlir::func::FuncOp funcOp =
createCopyFunc(currentLocation, converter, cpVar.getType(), attrs);
- copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
+ result.copyprivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
};
bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
@@ -714,9 +714,7 @@ bool ClauseProcessor::processCopyPrivate(
return hasCopyPrivate;
}
-bool ClauseProcessor::processDepend(
- llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
- llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
+bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Depend>(
@@ -731,7 +729,7 @@ bool ClauseProcessor::processDepend(
mlir::omp::ClauseTaskDependAttr dependTypeOperand =
genDependKindAttr(firOpBuilder, kind);
- dependTypeOperands.append(objects.size(), dependTypeOperand);
+ result.dependTypeAttrs.append(objects.size(), dependTypeOperand);
for (const omp::Object &object : objects) {
assert(object.ref() && "Expecting designator");
@@ -746,14 +744,14 @@ bool ClauseProcessor::processDepend(
Fortran::semantics::Symbol *sym = object.id();
const mlir::Value variable = converter.getSymbolAddress(*sym);
- dependOperands.push_back(variable);
+ result.dependVars.push_back(variable);
}
});
}
bool ClauseProcessor::processIf(
omp::clause::If::DirectiveNameModifier directiveName,
- mlir::Value &result) const {
+ mlir::omp::IfClauseOps &result) const {
bool found = false;
findRepeatableClause<omp::clause::If>(
[&](const omp::clause::If &clause,
@@ -764,7 +762,7 @@ bool ClauseProcessor::processIf(
// Assume that, at most, a single 'if' clause will be applicable to the
// given directive.
if (operand) {
- result = operand;
+ result.ifVar = operand;
found = true;
}
});
@@ -807,12 +805,10 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
bool ClauseProcessor::processMap(
mlir::Location currentLocation, const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
+ Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
- const {
+ llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause,
@@ -887,25 +883,23 @@ bool ClauseProcessor::processMap(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
- mapOperands.push_back(mapOp);
- if (mapSymTypes)
- mapSymTypes->push_back(symAddr.getType());
+ result.mapVars.push_back(mapOp);
+
+ if (mapSyms)
+ mapSyms->push_back(object.id());
if (mapSymLocs)
mapSymLocs->push_back(symAddr.getLoc());
-
- if (mapSymbols)
- mapSymbols->push_back(object.id());
+ if (mapSymTypes)
+ mapSymTypes->push_back(symAddr.getType());
}
});
}
bool ClauseProcessor::processReduction(
- mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &outReductionVars,
- llvm::SmallVectorImpl<mlir::Type> &outReductionTypes,
- llvm::SmallVectorImpl<mlir::Attribute> &outReductionDeclSymbols,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *outReductionSymbols) const {
+ mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
+ llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *outReductionSyms)
+ const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &clause,
const Fortran::parser::CharBlock &) {
@@ -915,30 +909,31 @@ bool ClauseProcessor::processReduction(
// whether to do the reduction byref.
llvm::SmallVector<mlir::Value> reductionVars;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(currentLocation, converter, clause,
reductionVars, reductionDeclSymbols,
- outReductionSymbols ? &reductionSymbols
- : nullptr);
+ outReductionSyms ? &reductionSyms : nullptr);
// Copy local lists into the output.
- llvm::copy(reductionVars, std::back_inserter(outReductionVars));
+ llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reductionDeclSymbols,
- std::back_inserter(outReductionDeclSymbols));
- if (outReductionSymbols)
- llvm::copy(reductionSymbols,
- std::back_inserter(*outReductionSymbols));
-
- outReductionTypes.reserve(outReductionTypes.size() +
- reductionVars.size());
- llvm::transform(reductionVars, std::back_inserter(outReductionTypes),
- [](mlir::Value v) { return v.getType(); });
+ std::back_inserter(result.reductionDeclSymbols));
+
+ if (outReductionTypes) {
+ outReductionTypes->reserve(outReductionTypes->size() +
+ reductionVars.size());
+ llvm::transform(reductionVars, std::back_inserter(*outReductionTypes),
+ [](mlir::Value v) { return v.getType(); });
+ }
+
+ if (outReductionSyms)
+ llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms));
});
}
bool ClauseProcessor::processSectionsReduction(
- mlir::Location currentLocation) const {
+ mlir::Location currentLocation, mlir::omp::ReductionClauseOps &) const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) {
TODO(currentLocation, "OMPC_Reduction");
@@ -967,30 +962,30 @@ bool ClauseProcessor::processEnter(
}
bool ClauseProcessor::processUseDeviceAddr(
- llvm::SmallVectorImpl<mlir::Value> &operands,
+ mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
const {
return findRepeatableClause<omp::clause::UseDeviceAddr>(
[&](const omp::clause::UseDeviceAddr &clause,
const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
- useDeviceLocs, useDeviceSymbols);
+ addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms);
});
}
bool ClauseProcessor::processUseDevicePtr(
- llvm::SmallVectorImpl<mlir::Value> &operands,
+ mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
const {
return findRepeatableClause<omp::clause::UseDevicePtr>(
[&](const omp::clause::UseDevicePtr &clause,
const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
- useDeviceLocs, useDeviceSymbols);
+ addUseDeviceClause(converter, clause.v, result.useDevicePtrVars,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms);
});
}
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index c0c603feb296af..d933e0a913d2bc 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -37,7 +37,7 @@ namespace omp {
/// corresponding clause if it is present in the clause list. Otherwise, they
/// will return `false` to signal that the clause was not found.
///
-/// The intended use is of this class is to move clause processing outside of
+/// The intended use of this class is to move clause processing outside of
/// construct processing, since the same clauses can appear attached to
/// different constructs and constructs can be combined, so that code
/// duplication is minimized.
@@ -56,94 +56,83 @@ class ClauseProcessor {
// 'Unique' clauses: They can appear at most once in the clause list.
bool processCollapse(
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const;
bool processDefault() const;
bool processDevice(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
+ mlir::omp::DeviceClauseOps &result) const;
+ bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processFinal(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processHint(mlir::IntegerAttr &result) const;
- bool processMergeable(mlir::UnitAttr &result) const;
- bool processNowait(mlir::UnitAttr &result) const;
+ mlir::omp::FinalClauseOps &result) const;
+ bool processHint(mlir::omp::HintClauseOps &result) const;
+ bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
+ bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
+ mlir::omp::NumTeamsClauseOps &result) const;
bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processOrdered(mlir::IntegerAttr &result) const;
+ mlir::omp::NumThreadsClauseOps &result) const;
+ bool processOrdered(mlir::omp::OrderedClauseOps &result) const;
bool processPriority(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const;
- bool processSafelen(mlir::IntegerAttr &result) const;
- bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr,
- mlir::omp::ScheduleModifierAttr &modifierAttr,
- mlir::UnitAttr &simdModifierAttr) const;
- bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processSimdlen(mlir::IntegerAttr &result) const;
+ mlir::omp::PriorityClauseOps &result) const;
+ bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
+ bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
+ bool processSchedule(Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ScheduleClauseOps &result) const;
+ bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const;
bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processUntied(mlir::UnitAttr &result) const;
+ mlir::omp::ThreadLimitClauseOps &result) const;
+ bool processUntied(mlir::omp::UntiedClauseOps &result) const;
// 'Repeatable' clauses: They can appear multiple times in the clause list.
- bool
- processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
- llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const;
+ bool processAllocate(mlir::omp::AllocateClauseOps &result) const;
bool processCopyin() const;
- bool processCopyPrivate(
- mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
- llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const;
- bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
- llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
+ bool processCopyprivate(mlir::Location currentLocation,
+ mlir::omp::CopyprivateClauseOps &result) const;
+ bool processDepend(mlir::omp::DependClauseOps &result) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
- mlir::Value &result) const;
+ mlir::omp::IfClauseOps &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
// This method is used to process a map clause.
- // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to
+ // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to
// store the original type, location and Fortran symbol for the map operands.
// They may be used later on to create the block_arguments for some of the
// target directives that require it.
- bool processMap(mlir::Location currentLocation,
- const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
- llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *mapSymbols = nullptr) const;
- bool
- processReduction(mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *reductionSymbols = nullptr) const;
- bool processSectionsReduction(mlir::Location currentLocation) const;
+ bool processMap(
+ mlir::Location currentLocation, const llvm::omp::Directive &directive,
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::MapClauseOps &result,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms =
+ nullptr,
+ llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
+ llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
+ bool processReduction(
+ mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
+ llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSyms =
+ nullptr) const;
+ bool processSectionsReduction(mlir::Location currentLocation,
+ mlir::omp::ReductionClauseOps &result) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
- processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
+ processUseDeviceAddr(mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) const;
+ &useDeviceSyms) const;
bool
- processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
+ processUseDevicePtr(mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) const;
+ &useDeviceSyms) const;
template <typename T>
bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands);
+ mlir::omp::MapClauseOps &result);
// Call this method for these clauses that should be supported but are not
// implemented yet. It triggers a compilation error if any of the given
@@ -185,7 +174,7 @@ class ClauseProcessor {
template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
+ mlir::omp::MapClauseOps &result) {
return findRepeatableClause<T>(
[&](const T &clause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
@@ -227,7 +216,7 @@ bool ClauseProcessor::processMotionClauses(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
- mapOperands.push_back(mapOp);
+ result.mapVars.push_back(mapOp);
}
});
}
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index e114ab9f4548ab..5a42e6a6aa4175 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -23,11 +23,13 @@ namespace Fortran {
namespace lower {
namespace omp {
-void DataSharingProcessor::processStep1() {
+// First privatization step: runs collectSymbolsForPrivatization() and
+// collectDefaultSymbols(), privatizes the collected symbols through
+// privatize() and defaultPrivatize(), then calls insertBarrier().
+// The optional clauseOps/privateSyms out-parameters are only forwarded;
+// doPrivatize() appends to each of them when it is non-null (privatizer
+// symbol and original address into clauseOps, Fortran symbol into
+// privateSyms). NOTE(review): that append happens after an omp.private
+// privatizer op is created, i.e. presumably only on the delayed
+// privatization path — confirm.
+void DataSharingProcessor::processStep1(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
collectSymbolsForPrivatization();
collectDefaultSymbols();
- privatize();
- defaultPrivatize();
+ // Both passes feed the same (possibly null) out-parameters.
+ privatize(clauseOps, privateSyms);
+ defaultPrivatize(clauseOps, privateSyms);
insertBarrier();
}
@@ -299,14 +301,16 @@ void DataSharingProcessor::collectDefaultSymbols() {
}
}
-void DataSharingProcessor::privatize() {
+// Calls doPrivatize() for every symbol in privatizedSymbols, forwarding the
+// optional clauseOps/privateSyms out-parameters unchanged. A COMMON block is
+// not privatized as a single entity: each of its member objects is
+// privatized individually.
+void DataSharingProcessor::privatize(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
+ // Expand COMMON blocks into their member objects.
if (const auto *commonDet =
sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
for (const auto &mem : commonDet->objects())
- doPrivatize(&*mem);
+ doPrivatize(&*mem, clauseOps, privateSyms);
} else
- doPrivatize(sym);
+ doPrivatize(sym, clauseOps, privateSyms);
}
}
@@ -323,7 +327,9 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
}
}
-void DataSharingProcessor::defaultPrivatize() {
+void DataSharingProcessor::defaultPrivatize(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
for (const Fortran::semantics::Symbol *sym : defaultSymbols) {
if (!Fortran::semantics::IsProcedure(*sym) &&
!sym->GetUltimate().has<Fortran::semantics::DerivedTypeDetails>() &&
@@ -331,11 +337,14 @@ void DataSharingProcessor::defaultPrivatize() {
!symbolsInNestedRegions.contains(sym) &&
!symbolsInParentRegions.contains(sym) &&
!privatizedSymbols.contains(sym))
- doPrivatize(sym);
+ doPrivatize(sym, clauseOps, privateSyms);
}
}
-void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) {
+void DataSharingProcessor::doPrivatize(
+ const Fortran::semantics::Symbol *sym,
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
if (!useDelayedPrivatization) {
cloneSymbol(sym);
copyFirstPrivateSymbol(sym);
@@ -419,10 +428,13 @@ void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) {
return result;
}();
- delayedPrivatizationInfo.privatizers.push_back(
- mlir::SymbolRefAttr::get(privatizerOp));
- delayedPrivatizationInfo.originalAddresses.push_back(hsb.getAddr());
- delayedPrivatizationInfo.symbols.push_back(sym);
+ if (clauseOps) {
+ clauseOps->privatizers.push_back(mlir::SymbolRefAttr::get(privatizerOp));
+ clauseOps->privateVars.push_back(hsb.getAddr());
+ }
+
+ if (privateSyms)
+ privateSyms->push_back(sym);
}
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
index 226abe96705e35..9724b3d5ed02fe 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -19,28 +19,17 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
+namespace mlir {
+namespace omp {
+// Forward declaration only: this header refers to PrivateClauseOps solely
+// through pointers (privatize/defaultPrivatize/doPrivatize/processStep1), so
+// the complete type is not required here. NOTE(review): the definition is
+// presumably in mlir/Dialect/OpenMP/OpenMPClauseOperands.h — confirm.
+struct PrivateClauseOps;
+} // namespace omp
+} // namespace mlir
+
namespace Fortran {
namespace lower {
namespace omp {
class DataSharingProcessor {
-public:
- /// Collects all the information needed for delayed privatization. This can be
- /// used by ops with data-sharing clauses to properly generate their regions
- /// (e.g. add region arguments) and map the original SSA values to their
- /// corresponding OMP region operands.
- struct DelayedPrivatizationInfo {
- // The list of symbols referring to delayed privatizer ops (i.e.
- // `omp.private` ops).
- llvm::SmallVector<mlir::SymbolRefAttr> privatizers;
- // SSA values that correspond to "original" values being privatized.
- // "Original" here means the SSA value outside the OpenMP region from which
- // a clone is created inside the region.
- llvm::SmallVector<mlir::Value> originalAddresses;
- // Fortran symbols corresponding to the above SSA values.
- llvm::SmallVector<const Fortran::semantics::Symbol *> symbols;
- };
-
private:
bool hasLastPrivateOp;
mlir::OpBuilder::InsertPoint lastPrivIP;
@@ -57,7 +46,6 @@ class DataSharingProcessor {
Fortran::lower::pft::Evaluation &eval;
bool useDelayedPrivatization;
Fortran::lower::SymMap *symTable;
- DelayedPrivatizationInfo delayedPrivatizationInfo;
bool needBarrier();
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
@@ -67,9 +55,16 @@ class DataSharingProcessor {
void collectSymbolsForPrivatization();
void insertBarrier();
void collectDefaultSymbols();
- void privatize();
- void defaultPrivatize();
- void doPrivatize(const Fortran::semantics::Symbol *sym);
+ void privatize(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
+ void defaultPrivatize(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
+ void doPrivatize(
+ const Fortran::semantics::Symbol *sym,
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
void copyLastPrivatize(mlir::Operation *op);
void insertLastPrivateCompare(mlir::Operation *op);
void cloneSymbol(const Fortran::semantics::Symbol *sym);
@@ -103,17 +98,15 @@ class DataSharingProcessor {
// Step2 performs the copying for lastprivates and requires knowledge of the
// MLIR operation to insert the last private update. Step2 adds
// dealocation code as well.
- void processStep1();
+ void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ *privateSyms = nullptr);
void processStep2(mlir::Operation *op, bool isLoop);
void setLoopIV(mlir::Value iv) {
assert(!loopIV && "Loop iteration variable already set");
loopIV = iv;
}
-
- const DelayedPrivatizationInfo &getDelayedPrivatizationInfo() const {
- return delayedPrivatizationInfo;
- }
};
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 0cf2a8f97040a8..d67060d1cce72b 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -523,19 +523,25 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation) {
return genOpWithBody<mlir::omp::MasterOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested),
- /*resultTypes=*/mlir::TypeRange());
+ .setGenNested(genNested));
}
+/// Generate an mlir::omp::OrderedRegionOp for an ORDERED directive, lowering
+/// its body through genOpWithBody. `clauseList` is only inspected to reject
+/// clauses that are not supported yet.
static mlir::omp::OrderedRegionOp
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation) {
+ mlir::Location currentLocation,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ // No clause is processed into clauseOps yet, so the op is built from its
+ // default-initialized values. NOTE(review): presumably equivalent to the
+ // previous explicit `/*simd=*/false` argument — confirm.
+ mlir::omp::OrderedRegionClauseOps clauseOps;
+
ClauseProcessor cp(converter, semaCtx, clauseList);
+ // SIMD on ORDERED is not implemented: processTODO triggers a compilation
+ // error if that clause is present.
+ cp.processTODO<clause::Simd>(currentLocation,
+ llvm::omp::Directive::OMPD_ordered);
+
return genOpWithBody<mlir::omp::OrderedRegionOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested),
- /*simd=*/false);
+ clauseOps);
}
static mlir::omp::ParallelOp
@@ -546,77 +552,62 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList,
bool outerCombined = false) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, numThreadsClauseOperand;
- mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
- reductionVars;
+ mlir::omp::ParallelClauseOps clauseOps;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_parallel, ifClauseOperand);
- cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
- cp.processProcBind(procBindKindAttr);
+ cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
+ cp.processNumThreads(stmtCtx, clauseOps);
+ cp.processProcBind(clauseOps);
cp.processDefault();
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processAllocate(clauseOps);
+
if (!outerCombined)
- cp.processReduction(currentLocation, reductionVars, reductionTypes,
- reductionDeclSymbols, &reductionSymbols);
+ cp.processReduction(currentLocation, clauseOps, &reductionTypes,
+ &reductionSyms);
+
+ if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
+ clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
auto reductionCallback = [&](mlir::Operation *op) {
- llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
+ llvm::SmallVector<mlir::Location> locs(clauseOps.reductionVars.size(),
currentLocation);
- auto *block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
- reductionTypes, locs);
+ auto *block =
+ firOpBuilder.createBlock(&op->getRegion(0), {}, reductionTypes, locs);
for (auto [arg, prv] :
- llvm::zip_equal(reductionSymbols, block->getArguments())) {
+ llvm::zip_equal(reductionSyms, block->getArguments())) {
converter.bindSymbol(*arg, prv);
}
- return reductionSymbols;
+ return reductionSyms;
};
- mlir::UnitAttr byrefAttr;
- if (ReductionProcessor::doReductionByRef(reductionVars))
- byrefAttr = converter.getFirOpBuilder().getUnitAttr();
-
OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauseList)
- .setReductions(&reductionSymbols, &reductionTypes)
+ .setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(reductionCallback);
- if (!enableDelayedPrivatization) {
- return genOpWithBody<mlir::omp::ParallelOp>(
- genInfo,
- /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
- numThreadsClauseOperand, allocateOperands, allocatorOperands,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- reductionDeclSymbols),
- procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
- /*privatizers=*/nullptr, byrefAttr);
- }
+ if (!enableDelayedPrivatization)
+ return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
bool privatize = !outerCombined;
DataSharingProcessor dsp(converter, semaCtx, clauseList, eval,
/*useDelayedPrivatization=*/true, &symTable);
if (privatize)
- dsp.processStep1();
-
- const auto &delayedPrivatizationInfo = dsp.getDelayedPrivatizationInfo();
+ dsp.processStep1(&clauseOps, &privateSyms);
auto genRegionEntryCB = [&](mlir::Operation *op) {
auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
- llvm::SmallVector<mlir::Location> reductionLocs(reductionVars.size(),
- currentLocation);
+ llvm::SmallVector<mlir::Location> reductionLocs(
+ clauseOps.reductionVars.size(), currentLocation);
mlir::OperandRange privateVars = parallelOp.getPrivateVars();
mlir::Region ®ion = parallelOp.getRegion();
@@ -631,12 +622,12 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::transform(privateVars, std::back_inserter(privateVarLocs),
[](mlir::Value v) { return v.getLoc(); });
- converter.getFirOpBuilder().createBlock(®ion, /*insertPt=*/{},
- privateVarTypes, privateVarLocs);
+ firOpBuilder.createBlock(®ion, /*insertPt=*/{}, privateVarTypes,
+ privateVarLocs);
llvm::SmallVector<const Fortran::semantics::Symbol *> allSymbols =
- reductionSymbols;
- allSymbols.append(delayedPrivatizationInfo.symbols);
+ reductionSyms;
+ allSymbols.append(privateSyms);
for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) {
converter.bindSymbol(*arg, prv);
}
@@ -646,26 +637,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
// TODO Merge with the reduction CB.
genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
-
- llvm::SmallVector<mlir::Attribute> privatizers(
- delayedPrivatizationInfo.privatizers.begin(),
- delayedPrivatizationInfo.privatizers.end());
-
- return genOpWithBody<mlir::omp::ParallelOp>(
- genInfo,
- /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
- numThreadsClauseOperand, allocateOperands, allocatorOperands,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- reductionDeclSymbols),
- procBindKindAttr, delayedPrivatizationInfo.originalAddresses,
- delayedPrivatizationInfo.privatizers.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- privatizers),
- byrefAttr);
+ return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
}
static mlir::omp::SectionOp
@@ -689,28 +661,21 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &beginClauseList,
const Fortran::parser::OmpClauseList &endClauseList) {
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
- llvm::SmallVector<mlir::Value> copyPrivateVars;
- llvm::SmallVector<mlir::Attribute> copyPrivateFuncs;
- mlir::UnitAttr nowaitAttr;
+ mlir::omp::SingleClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, beginClauseList);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processAllocate(clauseOps);
+ // TODO Support delayed privatization.
ClauseProcessor ecp(converter, semaCtx, endClauseList);
- ecp.processNowait(nowaitAttr);
- ecp.processCopyPrivate(currentLocation, copyPrivateVars, copyPrivateFuncs);
+ ecp.processNowait(clauseOps);
+ ecp.processCopyprivate(currentLocation, clauseOps);
return genOpWithBody<mlir::omp::SingleOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&beginClauseList),
- allocateOperands, allocatorOperands, copyPrivateVars,
- copyPrivateFuncs.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- copyPrivateFuncs),
- nowaitAttr);
+ clauseOps);
}
static mlir::omp::TaskOp
@@ -720,21 +685,19 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand;
- mlir::UnitAttr untiedAttr, mergeableAttr;
- llvm::SmallVector<mlir::Attribute> dependTypeOperands;
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
- dependOperands;
+ mlir::omp::TaskClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_task, ifClauseOperand);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
+ cp.processAllocate(clauseOps);
cp.processDefault();
- cp.processFinal(stmtCtx, finalClauseOperand);
- cp.processUntied(untiedAttr);
- cp.processMergeable(mergeableAttr);
- cp.processPriority(stmtCtx, priorityClauseOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
+ cp.processFinal(stmtCtx, clauseOps);
+ cp.processUntied(clauseOps);
+ cp.processMergeable(clauseOps);
+ cp.processPriority(stmtCtx, clauseOps);
+ cp.processDepend(clauseOps);
+ // TODO Support delayed privatization.
+
cp.processTODO<clause::InReduction, clause::Detach, clause::Affinity>(
currentLocation, llvm::omp::Directive::OMPD_task);
@@ -742,14 +705,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&clauseList),
- ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr,
- /*in_reduction_vars=*/mlir::ValueRange(),
- /*in_reductions=*/nullptr, priorityClauseOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, allocateOperands, allocatorOperands);
+ clauseOps);
}
static mlir::omp::TaskgroupOp
@@ -758,17 +714,18 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval, bool genNested,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
+ mlir::omp::TaskgroupClauseOps clauseOps;
+
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processAllocate(clauseOps);
cp.processTODO<clause::TaskReduction>(currentLocation,
llvm::omp::Directive::OMPD_taskgroup);
+
return genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&clauseList),
- /*task_reduction_vars=*/mlir::ValueRange(),
- /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
+ clauseOps);
}
// This helper function implements the functionality of "promoting"
@@ -789,8 +746,7 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
// clause. Support for such list items in a use_device_ptr clause
// is deprecated."
static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
- llvm::SmallVectorImpl<mlir::Value> &devicePtrOperands,
- llvm::SmallVectorImpl<mlir::Value> &deviceAddrOperands,
+ mlir::omp::UseDeviceClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -803,9 +759,10 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
// Iterate over our use_device_ptr list and shift all non-cptr arguments into
// use_device_addr.
- for (auto *it = devicePtrOperands.begin(); it != devicePtrOperands.end();) {
+ for (auto *it = clauseOps.useDevicePtrVars.begin();
+ it != clauseOps.useDevicePtrVars.end();) {
if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) {
- deviceAddrOperands.push_back(*it);
+ clauseOps.useDeviceAddrVars.push_back(*it);
// We have to shuffle the symbols around as well, to maintain
// the correct Input -> BlockArg for use_device_ptr/use_device_addr.
// NOTE: However, as map's do not seem to be included currently
@@ -813,11 +770,11 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
// future alterations. I believe the reason they are not currently
// is that the BlockArg assign/lowering needs to be extended
// to a greater set of types.
- auto idx = std::distance(devicePtrOperands.begin(), it);
+ auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it);
moveElementToBack(idx, useDeviceTypes);
moveElementToBack(idx, useDeviceLocs);
moveElementToBack(idx, useDeviceSymbols);
- it = devicePtrOperands.erase(it);
+ it = clauseOps.useDevicePtrVars.erase(it);
continue;
}
++it;
@@ -831,20 +788,19 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, deviceOperand;
- llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
- deviceAddrOperands;
+ mlir::omp::TargetDataClauseOps clauseOps;
llvm::SmallVector<mlir::Type> useDeviceTypes;
llvm::SmallVector<mlir::Location> useDeviceLocs;
- llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_target_data, ifClauseOperand);
- cp.processDevice(stmtCtx, deviceOperand);
- cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
- useDeviceSymbols);
- cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
- useDeviceSymbols);
+ cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs,
+ useDeviceSyms);
+ cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs,
+ useDeviceSyms);
+
// This function implements the deprecated functionality of use_device_ptr
// that allows users to provide non-CPTR arguments to it with the caveat
// that the compiler will treat them as use_device_addr. A lot of legacy
@@ -856,17 +812,16 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
// ordering.
// TODO: Perhaps create a user provideable compiler option that will
// re-introduce a hard-error rather than a warning in these cases.
- promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
- devicePtrOperands, deviceAddrOperands, useDeviceTypes, useDeviceLocs,
- useDeviceSymbols);
+ promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes,
+ useDeviceLocs, useDeviceSyms);
cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data,
- stmtCtx, mapOperands);
+ stmtCtx, clauseOps);
auto dataOp = converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>(
- currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
- deviceAddrOperands, mapOperands);
+ currentLocation, clauseOps);
+
genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, dataOp,
- useDeviceTypes, useDeviceLocs, useDeviceSymbols,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms,
currentLocation);
return dataOp;
}
@@ -879,10 +834,7 @@ static OpTy genTargetEnterExitDataUpdateOp(
const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, deviceOperand;
- mlir::UnitAttr nowaitAttr;
- llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
- llvm::SmallVector<mlir::Attribute> dependTypeOperands;
+ mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps;
// GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
[[maybe_unused]] llvm::omp::Directive directive;
@@ -897,25 +849,19 @@ static OpTy genTargetEnterExitDataUpdateOp(
}
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(directive, ifClauseOperand);
- cp.processDevice(stmtCtx, deviceOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
- cp.processNowait(nowaitAttr);
+ cp.processIf(directive, clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processDepend(clauseOps);
+ cp.processNowait(clauseOps);
if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) {
- cp.processMotionClauses<clause::To>(stmtCtx, mapOperands);
- cp.processMotionClauses<clause::From>(stmtCtx, mapOperands);
+ cp.processMotionClauses<clause::To>(stmtCtx, clauseOps);
+ cp.processMotionClauses<clause::From>(stmtCtx, clauseOps);
} else {
- cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
+ cp.processMap(currentLocation, directive, stmtCtx, clauseOps);
}
- return firOpBuilder.create<OpTy>(
- currentLocation, ifClauseOperand, deviceOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, nowaitAttr, mapOperands);
+ return firOpBuilder.create<OpTy>(currentLocation, clauseOps);
}
// This functions creates a block for the body of the targetOp's region. It adds
@@ -925,9 +871,9 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
mlir::omp::TargetOp &targetOp,
- llvm::ArrayRef<mlir::Type> mapSymTypes,
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms,
llvm::ArrayRef<mlir::Location> mapSymLocs,
- llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSymbols,
+ llvm::ArrayRef<mlir::Type> mapSymTypes,
const mlir::Location ¤tLocation) {
assert(mapSymTypes.size() == mapSymLocs.size());
@@ -956,7 +902,7 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
};
// Bind the symbols to their corresponding block arguments.
- for (auto [argIndex, argSymbol] : llvm::enumerate(mapSymbols)) {
+ for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) {
const mlir::BlockArgument &arg = region.getArgument(argIndex);
// Avoid capture of a reference to a structured binding.
const Fortran::semantics::Symbol *sym = argSymbol;
@@ -1080,22 +1026,20 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &clauseList,
llvm::omp::Directive directive, bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
- mlir::UnitAttr nowaitAttr;
- llvm::SmallVector<mlir::Attribute> dependTypeOperands;
- llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
- llvm::SmallVector<mlir::Type> mapSymTypes;
+ mlir::omp::TargetClauseOps clauseOps;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms;
llvm::SmallVector<mlir::Location> mapSymLocs;
- llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
+ llvm::SmallVector<mlir::Type> mapSymTypes;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_target, ifClauseOperand);
- cp.processDevice(stmtCtx, deviceOperand);
- cp.processThreadLimit(stmtCtx, threadLimitOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
- cp.processNowait(nowaitAttr);
- cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
- &mapSymLocs, &mapSymbols);
+ cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processThreadLimit(stmtCtx, clauseOps);
+ cp.processDepend(clauseOps);
+ cp.processNowait(clauseOps);
+ cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms,
+ &mapSymLocs, &mapSymTypes);
+ // TODO Support delayed privatization.
cp.processTODO<clause::Private, clause::Firstprivate, clause::IsDevicePtr,
clause::HasDeviceAddr, clause::Reduction, clause::InReduction,
@@ -1107,7 +1051,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
// symbols used inside the region that have not been explicitly mapped using
// the map clause.
auto captureImplicitMap = [&](const Fortran::semantics::Symbol &sym) {
- if (llvm::find(mapSymbols, &sym) == mapSymbols.end()) {
+ if (llvm::find(mapSyms, &sym) == mapSyms.end()) {
mlir::Value baseOp = converter.getSymbolAddress(sym);
if (!baseOp)
if (const auto *details = sym.template detailsIf<
@@ -1178,25 +1122,20 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
mapFlag),
captureKind, baseOp.getType());
- mapOperands.push_back(mapOp);
- mapSymTypes.push_back(baseOp.getType());
+ clauseOps.mapVars.push_back(mapOp);
+ mapSyms.push_back(&sym);
mapSymLocs.push_back(baseOp.getLoc());
- mapSymbols.push_back(&sym);
+ mapSymTypes.push_back(baseOp.getType());
}
}
};
Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
- currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, nowaitAttr, mapOperands);
+ currentLocation, clauseOps);
- genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
- mapSymLocs, mapSymbols, currentLocation);
+ genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms,
+ mapSymLocs, mapSymTypes, currentLocation);
return targetOp;
}
@@ -1209,17 +1148,16 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &clauseList,
bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand;
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
- reductionVars;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
+ mlir::omp::TeamsClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_teams, ifClauseOperand);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
+ cp.processAllocate(clauseOps);
cp.processDefault();
- cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
- cp.processThreadLimit(stmtCtx, threadLimitClauseOperand);
+ cp.processNumTeams(stmtCtx, clauseOps);
+ cp.processThreadLimit(stmtCtx, clauseOps);
+ // TODO Support delayed privatization.
+
cp.processTODO<clause::Reduction>(currentLocation,
llvm::omp::Directive::OMPD_teams);
@@ -1228,30 +1166,20 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
.setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauseList),
- /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
- threadLimitClauseOperand, allocateOperands, allocatorOperands,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- reductionDeclSymbols));
+ clauseOps);
}
/// Extract the list of function and variable symbols affected by the given
/// 'declare target' directive and return the intended device type for them.
-static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
+static void getDeclareTargetInfo(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+ mlir::omp::DeclareTargetClauseOps &clauseOps,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
-
- // The default capture type
- mlir::omp::DeclareTargetDeviceType deviceType =
- mlir::omp::DeclareTargetDeviceType::any;
const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
declareTargetConstruct.t);
-
if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
ObjectList objects{makeList(*objectList, semaCtx)};
@@ -1272,12 +1200,10 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
cp.processTo(symbolAndClause);
cp.processEnter(symbolAndClause);
cp.processLink(symbolAndClause);
- cp.processDeviceType(deviceType);
+ cp.processDeviceType(clauseOps);
cp.processTODO<clause::Indirect>(converter.getCurrentLocation(),
llvm::omp::Directive::OMPD_declare_target);
}
-
- return deviceType;
}
static void collectDeferredDeclareTargets(
@@ -1287,9 +1213,10 @@ static void collectDeferredDeclareTargets(
const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
llvm::SmallVectorImpl<Fortran::lower::OMPDeferredDeclareTargetInfo>
&deferredDeclareTarget) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
- mlir::omp::DeclareTargetDeviceType devType = getDeclareTargetInfo(
- converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
@@ -1299,8 +1226,9 @@ static void collectDeferredDeclareTargets(
std::get<const Fortran::semantics::Symbol &>(symClause)));
if (!op) {
- deferredDeclareTarget.push_back(
- {std::get<0>(symClause), devType, std::get<1>(symClause)});
+ deferredDeclareTarget.push_back({std::get<0>(symClause),
+ clauseOps.deviceType,
+ std::get<1>(symClause)});
}
}
}
@@ -1312,9 +1240,10 @@ getDeclareTargetFunctionDevice(
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
- mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
- converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
@@ -1324,7 +1253,7 @@ getDeclareTargetFunctionDevice(
std::get<const Fortran::semantics::Symbol &>(symClause)));
if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
- return deviceType;
+ return clauseOps.deviceType;
}
return std::nullopt;
@@ -1354,12 +1283,14 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_barrier:
firOpBuilder.create<mlir::omp::BarrierOp>(currentLocation);
break;
- case llvm::omp::Directive::OMPD_taskwait:
- ClauseProcessor(converter, semaCtx, opClauseList)
- .processTODO<clause::Depend, clause::Nowait>(
- currentLocation, llvm::omp::Directive::OMPD_taskwait);
- firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation);
+ case llvm::omp::Directive::OMPD_taskwait: {
+ mlir::omp::TaskwaitClauseOps clauseOps;
+ ClauseProcessor cp(converter, semaCtx, opClauseList);
+ cp.processTODO<clause::Depend, clause::Nowait>(
+ currentLocation, llvm::omp::Directive::OMPD_taskwait);
+ firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation, clauseOps);
break;
+ }
case llvm::omp::Directive::OMPD_taskyield:
firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation);
break;
@@ -1494,32 +1425,21 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value scheduleChunkClauseOperand, ifClauseOperand;
- llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
- llvm::SmallVector<mlir::Value> alignedVars, nontemporalVars;
+ mlir::omp::SimdLoopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
- llvm::SmallVector<mlir::Type> reductionTypes;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- mlir::omp::ClauseOrderKindAttr orderClauseOperand;
- mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand;
ClauseProcessor cp(converter, semaCtx, loopOpClauseList);
- cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
- cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
- cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols);
- cp.processIf(llvm::omp::Directive::OMPD_simd, ifClauseOperand);
- cp.processSimdlen(simdlenClauseOperand);
- cp.processSafelen(safelenClauseOperand);
+ cp.processCollapse(loc, eval, clauseOps, iv);
+ cp.processReduction(loc, clauseOps);
+ cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps);
+ cp.processSimdlen(clauseOps);
+ cp.processSafelen(clauseOps);
+ clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
+ // TODO Support delayed privatization.
+
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
clause::Nontemporal, clause::Order>(loc, ompDirective);
- mlir::TypeRange resultType;
- auto simdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
- loc, resultType, lowerBound, upperBound, step, alignedVars,
- /*alignment_values=*/nullptr, ifClauseOperand, nontemporalVars,
- orderClauseOperand, simdlenClauseOperand, safelenClauseOperand,
- /*inclusive=*/firOpBuilder.getUnitAttr());
-
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(loopOpClauseList));
@@ -1527,11 +1447,12 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
return genLoopVars(op, converter, loc, iv);
};
- createBodyOfOp<mlir::omp::SimdLoopOp>(
- simdLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&loopOpClauseList)
- .setDataSharingProcessor(&dsp)
- .setGenRegionEntryCb(ivCallback));
+ genOpWithBody<mlir::omp::SimdLoopOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
+ .setClauses(&loopOpClauseList)
+ .setDataSharingProcessor(&dsp)
+ .setGenRegionEntryCb(ivCallback),
+ clauseOps);
}
static void createWsloop(Fortran::lower::AbstractConverter &converter,
@@ -1546,77 +1467,50 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value scheduleChunkClauseOperand;
- llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
- llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
+ mlir::omp::WsloopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Type> reductionTypes;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
- mlir::omp::ClauseOrderKindAttr orderClauseOperand;
- mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
- mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
- mlir::IntegerAttr orderedClauseOperand;
- mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ClauseProcessor cp(converter, semaCtx, beginClauseList);
- cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
- cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
- cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols,
- &reductionSymbols);
- cp.processTODO<clause::Linear, clause::Order>(loc, ompDirective);
-
- if (ReductionProcessor::doReductionByRef(reductionVars))
- byrefOperand = firOpBuilder.getUnitAttr();
-
- auto wsLoopOp = firOpBuilder.create<mlir::omp::WsloopOp>(
- loc, lowerBound, upperBound, step, linearVars, linearStepVars,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(firOpBuilder.getContext(),
- reductionDeclSymbols),
- scheduleValClauseOperand, scheduleChunkClauseOperand,
- /*schedule_modifiers=*/nullptr,
- /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
- orderedClauseOperand, orderClauseOperand,
- /*inclusive=*/firOpBuilder.getUnitAttr());
-
- // Handle attribute based clauses.
- if (cp.processOrdered(orderedClauseOperand))
- wsLoopOp.setOrderedValAttr(orderedClauseOperand);
-
- if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand,
- scheduleSimdClauseOperand)) {
- wsLoopOp.setScheduleValAttr(scheduleValClauseOperand);
- wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand);
- wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
- }
+ cp.processCollapse(loc, eval, clauseOps, iv);
+ cp.processSchedule(stmtCtx, clauseOps);
+ cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
+ cp.processOrdered(clauseOps);
+ clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
+ // TODO Support delayed privatization.
+
+ if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
+ clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
+
+ cp.processTODO<clause::Allocate, clause::Linear, clause::Order>(loc,
+ ompDirective);
+
// In FORTRAN `nowait` clause occur at the end of `omp do` directive.
// i.e
// !$omp do
// <...>
// !$omp end do nowait
if (endClauseList) {
- if (ClauseProcessor(converter, semaCtx, *endClauseList)
- .processNowait(nowaitClauseOperand))
- wsLoopOp.setNowaitAttr(nowaitClauseOperand);
+ ClauseProcessor ecp(converter, semaCtx, *endClauseList);
+ ecp.processNowait(clauseOps);
}
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(beginClauseList));
auto ivCallback = [&](mlir::Operation *op) {
- return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols,
+ return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms,
reductionTypes);
};
- createBodyOfOp<mlir::omp::WsloopOp>(
- wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&beginClauseList)
- .setDataSharingProcessor(&dsp)
- .setReductions(&reductionSymbols, &reductionTypes)
- .setGenRegionEntryCb(ivCallback));
+ genOpWithBody<mlir::omp::WsloopOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
+ .setClauses(&beginClauseList)
+ .setDataSharingProcessor(&dsp)
+ .setReductions(&reductionSyms, &reductionTypes)
+ .setGenRegionEntryCb(ivCallback),
+ clauseOps);
}
static void createSimdWsloop(
@@ -1704,10 +1598,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
- mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
- converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
mlir::Operation *op = mod.lookupSymbol(converter.mangleName(
@@ -1721,7 +1616,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
markDeclareTarget(
op, converter,
- std::get<mlir::omp::DeclareTargetCaptureClause>(symClause), deviceType);
+ std::get<mlir::omp::DeclareTargetCaptureClause>(symClause),
+ clauseOps.deviceType);
}
}
@@ -1853,7 +1749,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
!std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
+ !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u) &&
+ !std::get_if<Fortran::parser::OmpClause::Simd>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
}
}
@@ -1873,7 +1770,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
break;
case llvm::omp::Directive::OMPD_ordered:
genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true,
- currentLocation);
+ currentLocation, beginClauseList);
break;
case llvm::omp::Directive::OMPD_parallel:
genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
@@ -1964,7 +1861,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
- mlir::IntegerAttr hintClauseOp;
std::string name;
const Fortran::parser::OmpCriticalDirective &cd =
std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
@@ -1973,21 +1869,28 @@ genOMP(Fortran::lower::AbstractConverter &converter,
std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
}
- const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
- ClauseProcessor(converter, semaCtx, clauseList).processHint(hintClauseOp);
-
mlir::omp::CriticalOp criticalOp = [&]() {
if (name.empty()) {
return firOpBuilder.create<mlir::omp::CriticalOp>(
currentLocation, mlir::FlatSymbolRefAttr());
}
+
mlir::ModuleOp module = firOpBuilder.getModule();
mlir::OpBuilder modBuilder(module.getBodyRegion());
auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
- if (!global)
- global = modBuilder.create<mlir::omp::CriticalDeclareOp>(
- currentLocation,
- mlir::StringAttr::get(firOpBuilder.getContext(), name), hintClauseOp);
+ if (!global) {
+ mlir::omp::CriticalClauseOps clauseOps;
+ const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
+
+ ClauseProcessor cp(converter, semaCtx, clauseList);
+ cp.processHint(clauseOps);
+ clauseOps.nameAttr =
+ mlir::StringAttr::get(firOpBuilder.getContext(), name);
+
+ global = modBuilder.create<mlir::omp::CriticalDeclareOp>(currentLocation,
+ clauseOps);
+ }
+
return firOpBuilder.create<mlir::omp::CriticalOp>(
currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(),
global.getSymName()));
@@ -2104,8 +2007,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) {
mlir::Location currentLocation = converter.getCurrentLocation();
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
- mlir::UnitAttr nowaitClauseOperand;
+ mlir::omp::SectionsClauseOps clauseOps;
const auto &beginSectionsDirective =
std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t);
const auto §ionsClauseList =
@@ -2114,8 +2016,9 @@ genOMP(Fortran::lower::AbstractConverter &converter,
// Process clauses before optional omp.parallel, so that new variables are
// allocated outside of the parallel region
ClauseProcessor cp(converter, semaCtx, sectionsClauseList);
- cp.processSectionsReduction(currentLocation);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processSectionsReduction(currentLocation, clauseOps);
+ cp.processAllocate(clauseOps);
+ // TODO Support delayed privatization.
llvm::omp::Directive dir =
std::get<Fortran::parser::OmpSectionsDirective>(beginSectionsDirective.t)
@@ -2132,16 +2035,14 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const auto &endSectionsClauseList =
std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
ClauseProcessor(converter, semaCtx, endSectionsClauseList)
- .processNowait(nowaitClauseOperand);
+ .processNowait(clauseOps);
}
// SECTIONS construct
genOpWithBody<mlir::omp::SectionsOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(false),
- /*reduction_vars=*/mlir::ValueRange(),
- /*reductions=*/nullptr, allocateOperands, allocatorOperands,
- nowaitClauseOperand);
+ clauseOps);
const auto §ionBlocks =
std::get<Fortran::parser::OmpSectionBlocks>(sectionsConstruct.t);
More information about the flang-commits
mailing list