[flang-commits] [flang] [llvm] [mlir] [Frontend][OpenMP] Implement getLeafOrCompositeConstructs (PR #89104)
Krzysztof Parzyszek via flang-commits
flang-commits at lists.llvm.org
Tue Apr 23 09:03:21 PDT 2024
https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/89104
>From c9ee93f1c28e9d4518ec842f66f497bf0911c9d5 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 26 Mar 2024 16:31:34 +0000
Subject: [PATCH 01/17] [MLIR][OpenMP] Group clause operands into structures
This patch introduces a set of composable structures grouping the MLIR operands
associated with each OpenMP clause. This makes it easier to keep the MLIR
representation for the same clause consistent throughout all operations that
accept it.
The relevant clause operand structures are grouped into per-operation
structures using a mixin pattern and used to define new operation constructors.
These constructors can be used to avoid having to get the order of a possibly
large list of operands right.
Missing clauses are documented as TODOs, as are operands that are part of the
relevant operation's operand structure but cannot yet be attached to the
associated operation because the corresponding op arguments are missing from
its MLIR definition.
A follow-up patch will update Flang lowering to make use of these structures,
simplifying the passing of information from clause processing to operation-
generating functions and also simplifying the creation of operations through
the use of the new operation constructors.
---
.../Dialect/OpenMP/OpenMPClauseOperands.h | 300 ++++++++++++++++++
.../mlir/Dialect/OpenMP/OpenMPDialect.h | 7 +-
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 72 ++++-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 226 ++++++++++++-
4 files changed, 595 insertions(+), 10 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
new file mode 100644
index 00000000000000..6454076f7593b3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -0,0 +1,300 @@
+//===-- OpenMPClauseOperands.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the structures defining MLIR operands associated with each
+// OpenMP clause, and structures grouping the appropriate operands for each
+// construct.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_
+#define MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc"
+
+namespace mlir {
+namespace omp {
+
+//===----------------------------------------------------------------------===//
+// Mixin structures defining MLIR operands associated with each OpenMP clause.
+//===----------------------------------------------------------------------===//
+
+struct AlignedClauseOps {
+ llvm::SmallVector<Value> alignedVars;
+ llvm::SmallVector<Attribute> alignmentAttrs;
+};
+
+struct AllocateClauseOps {
+ llvm::SmallVector<Value> allocatorVars, allocateVars;
+};
+
+struct CollapseClauseOps {
+ llvm::SmallVector<Value> loopLBVar, loopUBVar, loopStepVar;
+};
+
+struct CopyprivateClauseOps {
+ llvm::SmallVector<Value> copyprivateVars;
+ llvm::SmallVector<Attribute> copyprivateFuncs;
+};
+
+struct DependClauseOps {
+ llvm::SmallVector<Attribute> dependTypeAttrs;
+ llvm::SmallVector<Value> dependVars;
+};
+
+struct DeviceClauseOps {
+ Value deviceVar;
+};
+
+struct DeviceTypeClauseOps {
+ // The device type, which defaults to `any`.
+ DeclareTargetDeviceType deviceType = DeclareTargetDeviceType::any;
+};
+
+struct DistScheduleClauseOps {
+ UnitAttr distScheduleStaticAttr;
+ Value distScheduleChunkSizeVar;
+};
+
+struct DoacrossClauseOps {
+ llvm::SmallVector<Value> doacrossVectorVars;
+ ClauseDependAttr doacrossDependTypeAttr;
+ IntegerAttr doacrossNumLoopsAttr;
+};
+
+struct FinalClauseOps {
+ Value finalVar;
+};
+
+struct GrainsizeClauseOps {
+ Value grainsizeVar;
+};
+
+struct HintClauseOps {
+ IntegerAttr hintAttr;
+};
+
+struct IfClauseOps {
+ Value ifVar;
+};
+
+struct InReductionClauseOps {
+ llvm::SmallVector<Value> inReductionVars;
+ llvm::SmallVector<Attribute> inReductionDeclSymbols;
+};
+
+struct LinearClauseOps {
+ llvm::SmallVector<Value> linearVars, linearStepVars;
+};
+
+struct LoopRelatedOps {
+ UnitAttr loopInclusiveAttr;
+};
+
+struct MapClauseOps {
+ llvm::SmallVector<Value> mapVars;
+};
+
+struct MergeableClauseOps {
+ UnitAttr mergeableAttr;
+};
+
+struct NameClauseOps {
+ StringAttr nameAttr;
+};
+
+struct NogroupClauseOps {
+ UnitAttr nogroupAttr;
+};
+
+struct NontemporalClauseOps {
+ llvm::SmallVector<Value> nontemporalVars;
+};
+
+struct NowaitClauseOps {
+ UnitAttr nowaitAttr;
+};
+
+struct NumTasksClauseOps {
+ Value numTasksVar;
+};
+
+struct NumTeamsClauseOps {
+ Value numTeamsLowerVar, numTeamsUpperVar;
+};
+
+struct NumThreadsClauseOps {
+ Value numThreadsVar;
+};
+
+struct OrderClauseOps {
+ ClauseOrderKindAttr orderAttr;
+};
+
+struct OrderedClauseOps {
+ IntegerAttr orderedAttr;
+};
+
+struct ParallelizationLevelClauseOps {
+ UnitAttr parLevelSimdAttr;
+};
+
+struct PriorityClauseOps {
+ Value priorityVar;
+};
+
+struct PrivateClauseOps {
+ // SSA values that correspond to "original" values being privatized.
+ // They refer to the SSA value outside the OpenMP region from which a clone is
+ // created inside the region.
+ llvm::SmallVector<Value> privateVars;
+ // The list of symbols referring to delayed privatizer ops (i.e. `omp.private`
+ // ops).
+ llvm::SmallVector<Attribute> privatizers;
+};
+
+struct ProcBindClauseOps {
+ ClauseProcBindKindAttr procBindKindAttr;
+};
+
+struct ReductionClauseOps {
+ llvm::SmallVector<Value> reductionVars;
+ llvm::SmallVector<Attribute> reductionDeclSymbols;
+ UnitAttr reductionByRefAttr;
+};
+
+struct SafelenClauseOps {
+ IntegerAttr safelenAttr;
+};
+
+struct ScheduleClauseOps {
+ ClauseScheduleKindAttr scheduleValAttr;
+ ScheduleModifierAttr scheduleModAttr;
+ Value scheduleChunkVar;
+ UnitAttr scheduleSimdAttr;
+};
+
+struct SimdlenClauseOps {
+ IntegerAttr simdlenAttr;
+};
+
+struct TaskReductionClauseOps {
+ llvm::SmallVector<Value> taskReductionVars;
+ llvm::SmallVector<Attribute> taskReductionDeclSymbols;
+};
+
+struct ThreadLimitClauseOps {
+ Value threadLimitVar;
+};
+
+struct UntiedClauseOps {
+ UnitAttr untiedAttr;
+};
+
+struct UseDeviceClauseOps {
+ llvm::SmallVector<Value> useDevicePtrVars, useDeviceAddrVars;
+};
+
+//===----------------------------------------------------------------------===//
+// Structures defining clause operands associated with each OpenMP leaf
+// construct.
+//
+// These mirror the arguments expected by the corresponding OpenMP MLIR ops.
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+template <typename... Mixins>
+struct Clauses : public Mixins... {};
+} // namespace detail
+
+using CriticalClauseOps = detail::Clauses<HintClauseOps, NameClauseOps>;
+
+// TODO `indirect` clause.
+using DeclareTargetClauseOps = detail::Clauses<DeviceTypeClauseOps>;
+
+using DistributeClauseOps =
+ detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
+ PrivateClauseOps>;
+
+// TODO `filter` clause.
+using MaskedClauseOps = detail::Clauses<>;
+
+using OrderedOpClauseOps = detail::Clauses<DoacrossClauseOps>;
+
+using OrderedRegionClauseOps = detail::Clauses<ParallelizationLevelClauseOps>;
+
+using ParallelClauseOps =
+ detail::Clauses<AllocateClauseOps, IfClauseOps, NumThreadsClauseOps,
+ PrivateClauseOps, ProcBindClauseOps, ReductionClauseOps>;
+
+using SectionsClauseOps = detail::Clauses<AllocateClauseOps, NowaitClauseOps,
+ PrivateClauseOps, ReductionClauseOps>;
+
+// TODO `linear` clause.
+using SimdLoopClauseOps =
+ detail::Clauses<AlignedClauseOps, CollapseClauseOps, IfClauseOps,
+ LoopRelatedOps, NontemporalClauseOps, OrderClauseOps,
+ PrivateClauseOps, ReductionClauseOps, SafelenClauseOps,
+ SimdlenClauseOps>;
+
+using SingleClauseOps = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps,
+ NowaitClauseOps, PrivateClauseOps>;
+
+// TODO `defaultmap`, `has_device_addr`, `is_device_ptr`, `uses_allocators`
+// clauses.
+using TargetClauseOps =
+ detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps,
+ IfClauseOps, InReductionClauseOps, MapClauseOps,
+ NowaitClauseOps, PrivateClauseOps, ReductionClauseOps,
+ ThreadLimitClauseOps>;
+
+using TargetDataClauseOps = detail::Clauses<DeviceClauseOps, IfClauseOps,
+ MapClauseOps, UseDeviceClauseOps>;
+
+using TargetEnterExitUpdateDataClauseOps =
+ detail::Clauses<DependClauseOps, DeviceClauseOps, IfClauseOps, MapClauseOps,
+ NowaitClauseOps>;
+
+// TODO `affinity`, `detach` clauses.
+using TaskClauseOps =
+ detail::Clauses<AllocateClauseOps, DependClauseOps, FinalClauseOps,
+ IfClauseOps, InReductionClauseOps, MergeableClauseOps,
+ PriorityClauseOps, PrivateClauseOps, UntiedClauseOps>;
+
+using TaskgroupClauseOps =
+ detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>;
+
+using TaskloopClauseOps =
+ detail::Clauses<AllocateClauseOps, CollapseClauseOps, FinalClauseOps,
+ GrainsizeClauseOps, IfClauseOps, InReductionClauseOps,
+ LoopRelatedOps, MergeableClauseOps, NogroupClauseOps,
+ NumTasksClauseOps, PriorityClauseOps, PrivateClauseOps,
+ ReductionClauseOps, UntiedClauseOps>;
+
+using TaskwaitClauseOps = detail::Clauses<DependClauseOps, NowaitClauseOps>;
+
+using TeamsClauseOps =
+ detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
+ PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
+
+using WsloopClauseOps =
+ detail::Clauses<AllocateClauseOps, CollapseClauseOps, LinearClauseOps,
+ LoopRelatedOps, NowaitClauseOps, OrderClauseOps,
+ OrderedClauseOps, PrivateClauseOps, ReductionClauseOps,
+ ScheduleClauseOps>;
+
+} // namespace omp
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
index 23509c5b607016..c656bdc870976f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -26,11 +26,10 @@
#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
-#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
-#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
-#define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc"
+#include "mlir/Dialect/OpenMP/OpenMPClauseOperands.h"
+
+#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index f33942b3c7c02d..2643348d668698 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -287,7 +287,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
let regions = (region AnyRegion:$region);
let builders = [
- OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+ OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+ OpBuilder<(ins CArg<"const ParallelClauseOps &">:$clauses)>
];
let extraClassDeclaration = [{
/// Returns the number of reduction variables.
@@ -362,6 +363,10 @@ def TeamsOp : OpenMP_Op<"teams", [
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TeamsClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(
`num_teams` `(` ( $num_teams_lower^ `:` type($num_teams_lower) )? `to`
@@ -451,6 +456,10 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
let regions = (region SizedRegion<1>:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const SectionsClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist( `reduction` `(`
custom<ReductionVarList>(
@@ -495,6 +504,10 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const SingleClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`allocate` `(`
custom<AllocateAndAllocator>(
@@ -601,6 +614,7 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
OpBuilder<(ins "ValueRange":$lowerBound, "ValueRange":$upperBound,
"ValueRange":$step,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+ OpBuilder<(ins CArg<"const WsloopClauseOps &">:$clauses)>
];
let regions = (region AnyRegion:$region);
@@ -698,6 +712,11 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
);
let regions = (region AnyRegion:$region);
+
+ let builders = [
+ OpBuilder<(ins CArg<"const SimdLoopClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`aligned` `(`
custom<AlignedClause>($aligned_vars, type($aligned_vars),
@@ -781,6 +800,10 @@ def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const DistributeClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`dist_schedule_static` $dist_schedule_static
|`chunk_size` `(` $chunk_size `:` type($chunk_size) `)`
@@ -883,6 +906,9 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars);
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TaskClauseOps &">:$clauses)>
+ ];
let assemblyFormat = [{
oilist(`if` `(` $if_expr `)`
|`final` `(` $final_expr `)`
@@ -1037,6 +1063,10 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TaskloopClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`if` `(` $if_expr `)`
|`final` `(` $final_expr `)`
@@ -1106,6 +1136,10 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TaskgroupClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`task_reduction` `(`
custom<ReductionVarList>(
@@ -1432,6 +1466,10 @@ def TargetDataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments,
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TargetDataClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
| `device` `(` $device `:` type($device) `)`
@@ -1486,6 +1524,10 @@ def TargetEnterDataOp: OpenMP_Op<"target_enter_data",
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);
+ let builders = [
+ OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
| `device` `(` $device `:` type($device) `)`
@@ -1540,6 +1582,10 @@ def TargetExitDataOp: OpenMP_Op<"target_exit_data",
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);
+ let builders = [
+ OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
| `device` `(` $device `:` type($device) `)`
@@ -1596,6 +1642,10 @@ def TargetUpdateOp: OpenMP_Op<"target_update", [AttrSizedOperandSegments,
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$map_operands);
+ let builders = [
+ OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
| `device` `(` $device `:` type($device) `)`
@@ -1649,6 +1699,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const TargetClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
oilist( `if` `(` $if_expr `)`
| `device` `(` $device `:` type($device) `)`
@@ -1693,6 +1747,10 @@ def CriticalDeclareOp : OpenMP_Op<"critical.declare", [Symbol]> {
let arguments = (ins SymbolNameAttr:$sym_name,
DefaultValuedAttr<I64Attr, "0">:$hint_val);
+ let builders = [
+ OpBuilder<(ins CArg<"const CriticalClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
$sym_name oilist(`hint` `(` custom<SynchronizationHint>($hint_val) `)`)
attr-dict
@@ -1773,6 +1831,10 @@ def OrderedOp : OpenMP_Op<"ordered"> {
ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$num_loops_val,
Variadic<AnyType>:$depend_vec_vars);
+ let builders = [
+ OpBuilder<(ins CArg<"const OrderedOpClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{
( `depend_type` `` $depend_type_val^ )?
( `depend_vec` `(` $depend_vec_vars^ `:` type($depend_vec_vars) `)` )?
@@ -1797,6 +1859,10 @@ def OrderedRegionOp : OpenMP_Op<"ordered.region"> {
let regions = (region AnyRegion:$region);
+ let builders = [
+ OpBuilder<(ins CArg<"const OrderedRegionClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = [{ ( `simd` $simd^ )? $region attr-dict}];
let hasVerifier = 1;
}
@@ -1812,6 +1878,10 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> {
of the current task.
}];
+ let builders = [
+ OpBuilder<(ins CArg<"const TaskwaitClauseOps &">:$clauses)>
+ ];
+
let assemblyFormat = "attr-dict";
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index bf5875071e0dc4..28869c1ddfb3fd 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -41,6 +41,11 @@
using namespace mlir;
using namespace mlir::omp;
+static ArrayAttr makeArrayAttr(MLIRContext *context,
+ llvm::ArrayRef<Attribute> attrs) {
+ return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
+}
+
namespace {
struct MemRefPointerLikeModel
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
@@ -1161,6 +1166,17 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
return success();
}
+//===----------------------------------------------------------------------===//
+// TargetDataOp
+//===----------------------------------------------------------------------===//
+
+void TargetDataOp::build(OpBuilder &builder, OperationState &state,
+ const TargetDataClauseOps &clauses) {
+ TargetDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+ clauses.useDevicePtrVars, clauses.useDeviceAddrVars,
+ clauses.mapVars);
+}
+
LogicalResult TargetDataOp::verify() {
if (getMapOperands().empty() && getUseDevicePtr().empty() &&
getUseDeviceAddr().empty()) {
@@ -1170,6 +1186,20 @@ LogicalResult TargetDataOp::verify() {
return verifyMapClause(*this, getMapOperands());
}
+//===----------------------------------------------------------------------===//
+// TargetEnterDataOp
+//===----------------------------------------------------------------------===//
+
+void TargetEnterDataOp::build(
+ OpBuilder &builder, OperationState &state,
+ const TargetEnterExitUpdateDataClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ TargetEnterDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+ makeArrayAttr(ctx, clauses.dependTypeAttrs),
+ clauses.dependVars, clauses.nowaitAttr,
+ clauses.mapVars);
+}
+
LogicalResult TargetEnterDataOp::verify() {
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1177,6 +1207,20 @@ LogicalResult TargetEnterDataOp::verify() {
: verifyMapClause(*this, getMapOperands());
}
+//===----------------------------------------------------------------------===//
+// TargetExitDataOp
+//===----------------------------------------------------------------------===//
+
+void TargetExitDataOp::build(
+ OpBuilder &builder, OperationState &state,
+ const TargetEnterExitUpdateDataClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ TargetExitDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+ makeArrayAttr(ctx, clauses.dependTypeAttrs),
+ clauses.dependVars, clauses.nowaitAttr,
+ clauses.mapVars);
+}
+
LogicalResult TargetExitDataOp::verify() {
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1184,6 +1228,19 @@ LogicalResult TargetExitDataOp::verify() {
: verifyMapClause(*this, getMapOperands());
}
+//===----------------------------------------------------------------------===//
+// TargetUpdateOp
+//===----------------------------------------------------------------------===//
+
+void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
+ const TargetEnterExitUpdateDataClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ TargetUpdateOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+ makeArrayAttr(ctx, clauses.dependTypeAttrs),
+ clauses.dependVars, clauses.nowaitAttr,
+ clauses.mapVars);
+}
+
LogicalResult TargetUpdateOp::verify() {
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1191,6 +1248,22 @@ LogicalResult TargetUpdateOp::verify() {
: verifyMapClause(*this, getMapOperands());
}
+//===----------------------------------------------------------------------===//
+// TargetOp
+//===----------------------------------------------------------------------===//
+
+void TargetOp::build(OpBuilder &builder, OperationState &state,
+ const TargetClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
+ // inReductionDeclSymbols, privateVars, privatizers, reductionVars,
+ // reductionByRefAttr, reductionDeclSymbols.
+ TargetOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
+ clauses.threadLimitVar,
+ makeArrayAttr(ctx, clauses.dependTypeAttrs),
+ clauses.dependVars, clauses.nowaitAttr, clauses.mapVars);
+}
+
LogicalResult TargetOp::verify() {
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1213,6 +1286,17 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
state.addAttributes(attributes);
}
+void ParallelOp::build(OpBuilder &builder, OperationState &state,
+ const ParallelClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ ParallelOp::build(
+ builder, state, clauses.ifVar, clauses.numThreadsVar,
+ clauses.allocateVars, clauses.allocatorVars, clauses.reductionVars,
+ makeArrayAttr(ctx, clauses.reductionDeclSymbols),
+ clauses.procBindKindAttr, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privatizers), clauses.reductionByRefAttr);
+}
+
template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
auto privateVars = op.getPrivateVars();
@@ -1280,6 +1364,17 @@ static bool opInGlobalImplicitParallelRegion(Operation *op) {
return true;
}
+void TeamsOp::build(OpBuilder &builder, OperationState &state,
+ const TeamsClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
+ TeamsOp::build(builder, state, clauses.numTeamsLowerVar,
+ clauses.numTeamsUpperVar, clauses.ifVar,
+ clauses.threadLimitVar, clauses.allocateVars,
+ clauses.allocatorVars, clauses.reductionVars,
+ makeArrayAttr(ctx, clauses.reductionDeclSymbols));
+}
+
LogicalResult TeamsOp::verify() {
// Check parent region
// TODO If nested inside of a target region, also check that it does not
@@ -1312,9 +1407,19 @@ LogicalResult TeamsOp::verify() {
}
//===----------------------------------------------------------------------===//
-// Verifier for SectionsOp
+// SectionsOp
//===----------------------------------------------------------------------===//
+void SectionsOp::build(OpBuilder &builder, OperationState &state,
+ const SectionsClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
+ SectionsOp::build(builder, state, clauses.reductionVars,
+ makeArrayAttr(ctx, clauses.reductionDeclSymbols),
+ clauses.allocateVars, clauses.allocatorVars,
+ clauses.nowaitAttr);
+}
+
LogicalResult SectionsOp::verify() {
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
@@ -1334,6 +1439,20 @@ LogicalResult SectionsOp::verifyRegions() {
return success();
}
+//===----------------------------------------------------------------------===//
+// SingleOp
+//===----------------------------------------------------------------------===//
+
+void SingleOp::build(OpBuilder &builder, OperationState &state,
+ const SingleClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ // TODO Store clauses in op: privateVars, privatizers.
+ SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.copyprivateVars,
+ makeArrayAttr(ctx, clauses.copyprivateFuncs),
+ clauses.nowaitAttr);
+}
+
LogicalResult SingleOp::verify() {
// Check for allocate clause restrictions
if (getAllocateVars().size() != getAllocatorsVars().size())
@@ -1481,9 +1600,21 @@ void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
}
//===----------------------------------------------------------------------===//
-// Verifier for Simd construct [2.9.3.1]
+// Simd construct [2.9.3.1]
//===----------------------------------------------------------------------===//
+void SimdLoopOp::build(OpBuilder &builder, OperationState &state,
+ const SimdLoopClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ // TODO Store clauses in op: privateVars, reductionByRefAttr, reductionVars,
+ // privatizers, reductionDeclSymbols.
+ SimdLoopOp::build(
+ builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar,
+ clauses.alignedVars, makeArrayAttr(ctx, clauses.alignmentAttrs),
+ clauses.ifVar, clauses.nontemporalVars, clauses.orderAttr,
+ clauses.simdlenAttr, clauses.safelenAttr, clauses.loopInclusiveAttr);
+}
+
LogicalResult SimdLoopOp::verify() {
if (this->getLowerBound().empty()) {
return emitOpError() << "empty lowerbound for simd loop operation";
@@ -1504,9 +1635,17 @@ LogicalResult SimdLoopOp::verify() {
}
//===----------------------------------------------------------------------===//
-// Verifier for Distribute construct [2.9.4.1]
+// Distribute construct [2.9.4.1]
//===----------------------------------------------------------------------===//
+void DistributeOp::build(OpBuilder &builder, OperationState &state,
+ const DistributeClauseOps &clauses) {
+ // TODO Store clauses in op: privateVars, privatizers.
+ DistributeOp::build(builder, state, clauses.distScheduleStaticAttr,
+ clauses.distScheduleChunkSizeVar, clauses.allocateVars,
+ clauses.allocatorVars, clauses.orderAttr);
+}
+
LogicalResult DistributeOp::verify() {
if (this->getChunkSize() && !this->getDistScheduleStatic())
return emitOpError() << "chunk size set without "
@@ -1607,6 +1746,19 @@ LogicalResult ReductionOp::verify() {
//===----------------------------------------------------------------------===//
// TaskOp
//===----------------------------------------------------------------------===//
+
+void TaskOp::build(OpBuilder &builder, OperationState &state,
+ const TaskClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ // TODO Store clauses in op: privateVars, privatizers.
+ TaskOp::build(
+ builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
+ clauses.mergeableAttr, clauses.inReductionVars,
+ makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
+ makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
+ clauses.allocateVars, clauses.allocatorVars);
+}
+
LogicalResult TaskOp::verify() {
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
@@ -1619,6 +1771,15 @@ LogicalResult TaskOp::verify() {
//===----------------------------------------------------------------------===//
// TaskgroupOp
//===----------------------------------------------------------------------===//
+
+void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
+ const TaskgroupClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ TaskgroupOp::build(builder, state, clauses.taskReductionVars,
+ makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
+ clauses.allocateVars, clauses.allocatorVars);
+}
+
LogicalResult TaskgroupOp::verify() {
return verifyReductionVarList(*this, getTaskReductions(),
getTaskReductionVars());
@@ -1627,6 +1788,21 @@ LogicalResult TaskgroupOp::verify() {
//===----------------------------------------------------------------------===//
// TaskloopOp
//===----------------------------------------------------------------------===//
+
+void TaskloopOp::build(OpBuilder &builder, OperationState &state,
+ const TaskloopClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
+ TaskloopOp::build(
+ builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar,
+ clauses.loopInclusiveAttr, clauses.ifVar, clauses.finalVar,
+ clauses.untiedAttr, clauses.mergeableAttr, clauses.inReductionVars,
+ makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
+ makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
+ clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
+ clauses.numTasksVar, clauses.nogroupAttr);
+}
+
SmallVector<Value> TaskloopOp::getAllReductionVars() {
SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
getInReductionVars().end());
@@ -1680,14 +1856,33 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
state.addAttributes(attributes);
}
+void WsloopOp::build(OpBuilder &builder, OperationState &state,
+ const WsloopClauseOps &clauses) {
+ MLIRContext *ctx = builder.getContext();
+ // TODO Store clauses in op: allocateVars, allocatorVars, privateVars,
+ // privatizers.
+ WsloopOp::build(
+ builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar,
+ clauses.linearVars, clauses.linearStepVars, clauses.reductionVars,
+ makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.scheduleValAttr,
+ clauses.scheduleChunkVar, clauses.scheduleModAttr,
+ clauses.scheduleSimdAttr, clauses.nowaitAttr, clauses.reductionByRefAttr,
+ clauses.orderedAttr, clauses.orderAttr, clauses.loopInclusiveAttr);
+}
+
LogicalResult WsloopOp::verify() {
return verifyReductionVarList(*this, getReductions(), getReductionVars());
}
//===----------------------------------------------------------------------===//
-// Verifier for critical construct (2.17.1)
+// Critical construct (2.17.1)
//===----------------------------------------------------------------------===//
+void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
+ const CriticalClauseOps &clauses) {
+ CriticalDeclareOp::build(builder, state, clauses.nameAttr, clauses.hintAttr);
+}
+
LogicalResult CriticalDeclareOp::verify() {
return verifySynchronizationHint(*this, getHintVal());
}
@@ -1707,9 +1902,15 @@ LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
//===----------------------------------------------------------------------===//
-// Verifier for ordered construct
+// Ordered construct
//===----------------------------------------------------------------------===//
+void OrderedOp::build(OpBuilder &builder, OperationState &state,
+ const OrderedOpClauseOps &clauses) {
+ OrderedOp::build(builder, state, clauses.doacrossDependTypeAttr,
+ clauses.doacrossNumLoopsAttr, clauses.doacrossVectorVars);
+}
+
LogicalResult OrderedOp::verify() {
auto container = (*this)->getParentOfType<WsloopOp>();
if (!container || !container.getOrderedValAttr() ||
@@ -1726,6 +1927,11 @@ LogicalResult OrderedOp::verify() {
return success();
}
+void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
+ const OrderedRegionClauseOps &clauses) {
+ OrderedRegionOp::build(builder, state, clauses.parLevelSimdAttr);
+}
+
LogicalResult OrderedRegionOp::verify() {
// TODO: The code generation for ordered simd directive is not supported yet.
if (getSimd())
@@ -1742,6 +1948,16 @@ LogicalResult OrderedRegionOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// TaskwaitOp
+//===----------------------------------------------------------------------===//
+
+void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
+ const TaskwaitClauseOps &clauses) {
+ // TODO Store clauses in op: dependTypeAttrs, dependVars, nowaitAttr.
+ TaskwaitOp::build(builder, state);
+}
+
//===----------------------------------------------------------------------===//
// Verifier for AtomicReadOp
//===----------------------------------------------------------------------===//
>From 7af7e9d13fc2134e76bb532bfa4313aa3df17924 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 26 Mar 2024 16:46:56 +0000
Subject: [PATCH 02/17] [Flang][OpenMP][Lower] Use clause operand structures
This patch updates Flang lowering to use the new set of OpenMP clause operand
structures and their groupings into directive-specific sets of clause operands.
It simplifies the passing of information from the clause processor and the
creation of operations.
The `DataSharingProcessor` is slightly modified to not hold delayed
privatization state. Instead, optional arguments are added to `processStep1`
which are only passed when delayed privatization is used. This enables using
the clause operand structure for `private` and removes the need for the ad-hoc
`DelayedPrivatizationInfo` structure.
The processing of the `schedule` clause is updated to process the `chunk`
modifier rather than requiring two separate calls to the `ClauseProcessor`.
Lowering of a block-associated `ordered` construct is updated to emit a TODO
error if the `simd` clause is specified, since it is not currently supported by
the `ClauseProcessor` or later compilation stages.
Removed processing of `schedule` from `omp.simdloop`, as it doesn't apply to
`simd` constructs.
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 261 +++++----
flang/lib/Lower/OpenMP/ClauseProcessor.h | 105 ++--
.../lib/Lower/OpenMP/DataSharingProcessor.cpp | 38 +-
flang/lib/Lower/OpenMP/DataSharingProcessor.h | 45 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 517 +++++++-----------
5 files changed, 428 insertions(+), 538 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 0a57a1496289f4..ee1f6c2fbc7e89 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -162,14 +162,13 @@ getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
ifVal);
}
-static void
-addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
- const omp::ObjectList &objects,
- llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) {
+static void addUseDeviceClause(
+ Fortran::lower::AbstractConverter &converter,
+ const omp::ObjectList &objects,
+ llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) {
genObjectList(objects, converter, operands);
for (mlir::Value &operand : operands) {
checkMapType(operand.getLoc(), operand.getType());
@@ -177,25 +176,24 @@ addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
useDeviceLocs.push_back(operand.getLoc());
}
for (const omp::Object &object : objects)
- useDeviceSymbols.push_back(object.id());
+ useDeviceSyms.push_back(object.id());
}
static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
mlir::Location loc,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
std::size_t loopVarTypeSize) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
// The types of lower bound, upper bound, and step are converted into the
// type of the loop variable if necessary.
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
- lowerBound[it] =
- firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
- upperBound[it] =
- firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
- step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
+ for (unsigned it = 0; it < (unsigned)result.loopLBVar.size(); it++) {
+ result.loopLBVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopLBVar[it]);
+ result.loopUBVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopUBVar[it]);
+ result.loopStepVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopStepVar[it]);
}
}
@@ -205,9 +203,7 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
bool ClauseProcessor::processCollapse(
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const {
bool found = false;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -238,15 +234,15 @@ bool ClauseProcessor::processCollapse(
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
assert(bounds && "Expected bounds for worksharing do loop");
Fortran::lower::StatementContext stmtCtx;
- lowerBound.push_back(fir::getBase(converter.genExprValue(
+ result.loopLBVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
- upperBound.push_back(fir::getBase(converter.genExprValue(
+ result.loopUBVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
if (bounds->step) {
- step.push_back(fir::getBase(converter.genExprValue(
+ result.loopStepVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
} else { // If `step` is not present, assume it as `1`.
- step.push_back(firOpBuilder.createIntegerConstant(
+ result.loopStepVar.push_back(firOpBuilder.createIntegerConstant(
currentLocation, firOpBuilder.getIntegerType(32), 1));
}
iv.push_back(bounds->name.thing.symbol);
@@ -257,8 +253,7 @@ bool ClauseProcessor::processCollapse(
&*std::next(doConstructEval->getNestedEvaluations().begin());
} while (collapseValue > 0);
- convertLoopBounds(converter, currentLocation, lowerBound, upperBound, step,
- loopVarTypeSize);
+ convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
return found;
}
@@ -286,7 +281,7 @@ bool ClauseProcessor::processDefault() const {
}
bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+ mlir::omp::DeviceClauseOps &result) const {
const Fortran::parser::CharBlock *source = nullptr;
if (auto *clause = findUniqueClause<omp::clause::Device>(&source)) {
mlir::Location clauseLocation = converter.genLocation(*source);
@@ -298,25 +293,26 @@ bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
}
}
const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
- result = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
+ result.deviceVar =
+ fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processDeviceType(
- mlir::omp::DeclareTargetDeviceType &result) const {
+ mlir::omp::DeviceTypeClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::DeviceType>()) {
// Case: declare target ... device_type(any | host | nohost)
switch (clause->v) {
case omp::clause::DeviceType::DeviceTypeDescription::Nohost:
- result = mlir::omp::DeclareTargetDeviceType::nohost;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost;
break;
case omp::clause::DeviceType::DeviceTypeDescription::Host:
- result = mlir::omp::DeclareTargetDeviceType::host;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::host;
break;
case omp::clause::DeviceType::DeviceTypeDescription::Any:
- result = mlir::omp::DeclareTargetDeviceType::any;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::any;
break;
}
return true;
@@ -325,7 +321,7 @@ bool ClauseProcessor::processDeviceType(
}
bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+ mlir::omp::FinalClauseOps &result) const {
const Fortran::parser::CharBlock *source = nullptr;
if (auto *clause = findUniqueClause<omp::clause::Final>(&source)) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -333,100 +329,108 @@ bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
mlir::Value finalVal =
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
- result = firOpBuilder.createConvert(clauseLocation,
- firOpBuilder.getI1Type(), finalVal);
+ result.finalVar = firOpBuilder.createConvert(
+ clauseLocation, firOpBuilder.getI1Type(), finalVal);
return true;
}
return false;
}
-bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t hintValue = *Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(hintValue);
+ result.hintAttr = firOpBuilder.getI64IntegerAttr(hintValue);
return true;
}
return false;
}
-bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Mergeable>(result);
+bool ClauseProcessor::processMergeable(
+ mlir::omp::MergeableClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Mergeable>(result.mergeableAttr);
}
-bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Nowait>(result);
+bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Nowait>(result.nowaitAttr);
}
-bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+bool ClauseProcessor::processNumTeams(
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::NumTeamsClauseOps &result) const {
// TODO Get lower and upper bounds for num_teams when parser is updated to
// accept both.
if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
// auto lowerBound = std::get<std::optional<ExprTy>>(clause->t);
auto &upperBound = std::get<ExprTy>(clause->t);
- result = fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+ result.numTeamsUpperVar =
+ fir::getBase(converter.genExprValue(upperBound, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processNumThreads(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.numThreadsVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
-bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processOrdered(
+ mlir::omp::OrderedClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Ordered>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t orderedClauseValue = 0l;
if (clause->v.has_value())
orderedClauseValue = *Fortran::evaluate::ToInt64(*clause->v);
- result = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
+ result.orderedAttr = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
return true;
}
return false;
}
-bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+bool ClauseProcessor::processPriority(
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::PriorityClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.priorityVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processProcBind(
- mlir::omp::ClauseProcBindKindAttr &result) const {
+ mlir::omp::ProcBindClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- result = genProcBindKindAttr(firOpBuilder, *clause);
+ result.procBindKindAttr = genProcBindKindAttr(firOpBuilder, *clause);
return true;
}
return false;
}
-bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processSafelen(
+ mlir::omp::SafelenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
const std::optional<std::int64_t> safelenVal =
Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(*safelenVal);
+ result.safelenAttr = firOpBuilder.getI64IntegerAttr(*safelenVal);
return true;
}
return false;
}
bool ClauseProcessor::processSchedule(
- mlir::omp::ClauseScheduleKindAttr &valAttr,
- mlir::omp::ScheduleModifierAttr &modifierAttr,
- mlir::UnitAttr &simdModifierAttr) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ScheduleClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::MLIRContext *context = firOpBuilder.getContext();
@@ -451,53 +455,51 @@ bool ClauseProcessor::processSchedule(
break;
}
- mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
+ result.scheduleValAttr =
+ mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
+ mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
if (scheduleModifier != mlir::omp::ScheduleModifier::none)
- modifierAttr =
+ result.scheduleModAttr =
mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
- simdModifierAttr = firOpBuilder.getUnitAttr();
+ result.scheduleSimdAttr = firOpBuilder.getUnitAttr();
- valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processScheduleChunk(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
- if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
- result = fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
+ result.scheduleChunkVar =
+ fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
+
return true;
}
return false;
}
-bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processSimdlen(
+ mlir::omp::SimdlenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
const std::optional<std::int64_t> simdlenVal =
Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(*simdlenVal);
+ result.simdlenAttr = firOpBuilder.getI64IntegerAttr(*simdlenVal);
return true;
}
return false;
}
bool ClauseProcessor::processThreadLimit(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ThreadLimitClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.threadLimitVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
-bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Untied>(result);
+bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Untied>(result.untiedAttr);
}
//===----------------------------------------------------------------------===//
@@ -505,13 +507,12 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
//===----------------------------------------------------------------------===//
bool ClauseProcessor::processAllocate(
- llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
- llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
+ mlir::omp::AllocateClauseOps &result) const {
return findRepeatableClause<omp::clause::Allocate>(
[&](const omp::clause::Allocate &clause,
const Fortran::parser::CharBlock &) {
- genAllocateClause(converter, clause, allocatorOperands,
- allocateOperands);
+ genAllocateClause(converter, clause, result.allocatorVars,
+ result.allocateVars);
});
}
@@ -660,10 +661,9 @@ createCopyFunc(mlir::Location loc, Fortran::lower::AbstractConverter &converter,
return funcOp;
}
-bool ClauseProcessor::processCopyPrivate(
+bool ClauseProcessor::processCopyprivate(
mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
- llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const {
+ mlir::omp::CopyprivateClauseOps &result) const {
auto addCopyPrivateVar = [&](Fortran::semantics::Symbol *sym) {
mlir::Value symVal = converter.getSymbolAddress(*sym);
auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>();
@@ -690,10 +690,10 @@ bool ClauseProcessor::processCopyPrivate(
cpVar = alloca;
}
- copyPrivateVars.push_back(cpVar);
+ result.copyprivateVars.push_back(cpVar);
mlir::func::FuncOp funcOp =
createCopyFunc(currentLocation, converter, cpVar.getType(), attrs);
- copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
+ result.copyprivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
};
bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
@@ -714,9 +714,7 @@ bool ClauseProcessor::processCopyPrivate(
return hasCopyPrivate;
}
-bool ClauseProcessor::processDepend(
- llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
- llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
+bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Depend>(
@@ -731,7 +729,7 @@ bool ClauseProcessor::processDepend(
mlir::omp::ClauseTaskDependAttr dependTypeOperand =
genDependKindAttr(firOpBuilder, kind);
- dependTypeOperands.append(objects.size(), dependTypeOperand);
+ result.dependTypeAttrs.append(objects.size(), dependTypeOperand);
for (const omp::Object &object : objects) {
assert(object.ref() && "Expecting designator");
@@ -746,14 +744,14 @@ bool ClauseProcessor::processDepend(
Fortran::semantics::Symbol *sym = object.id();
const mlir::Value variable = converter.getSymbolAddress(*sym);
- dependOperands.push_back(variable);
+ result.dependVars.push_back(variable);
}
});
}
bool ClauseProcessor::processIf(
omp::clause::If::DirectiveNameModifier directiveName,
- mlir::Value &result) const {
+ mlir::omp::IfClauseOps &result) const {
bool found = false;
findRepeatableClause<omp::clause::If>(
[&](const omp::clause::If &clause,
@@ -764,7 +762,7 @@ bool ClauseProcessor::processIf(
// Assume that, at most, a single 'if' clause will be applicable to the
// given directive.
if (operand) {
- result = operand;
+ result.ifVar = operand;
found = true;
}
});
@@ -807,12 +805,10 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
bool ClauseProcessor::processMap(
mlir::Location currentLocation, const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
+ Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
- const {
+ llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause,
@@ -887,25 +883,23 @@ bool ClauseProcessor::processMap(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
- mapOperands.push_back(mapOp);
- if (mapSymTypes)
- mapSymTypes->push_back(symAddr.getType());
+ result.mapVars.push_back(mapOp);
+
+ if (mapSyms)
+ mapSyms->push_back(object.id());
if (mapSymLocs)
mapSymLocs->push_back(symAddr.getLoc());
-
- if (mapSymbols)
- mapSymbols->push_back(object.id());
+ if (mapSymTypes)
+ mapSymTypes->push_back(symAddr.getType());
}
});
}
bool ClauseProcessor::processReduction(
- mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &outReductionVars,
- llvm::SmallVectorImpl<mlir::Type> &outReductionTypes,
- llvm::SmallVectorImpl<mlir::Attribute> &outReductionDeclSymbols,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *outReductionSymbols) const {
+ mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
+ llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *outReductionSyms)
+ const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &clause,
const Fortran::parser::CharBlock &) {
@@ -915,30 +909,31 @@ bool ClauseProcessor::processReduction(
// whether to do the reduction byref.
llvm::SmallVector<mlir::Value> reductionVars;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(currentLocation, converter, clause,
reductionVars, reductionDeclSymbols,
- outReductionSymbols ? &reductionSymbols
- : nullptr);
+ outReductionSyms ? &reductionSyms : nullptr);
// Copy local lists into the output.
- llvm::copy(reductionVars, std::back_inserter(outReductionVars));
+ llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reductionDeclSymbols,
- std::back_inserter(outReductionDeclSymbols));
- if (outReductionSymbols)
- llvm::copy(reductionSymbols,
- std::back_inserter(*outReductionSymbols));
-
- outReductionTypes.reserve(outReductionTypes.size() +
- reductionVars.size());
- llvm::transform(reductionVars, std::back_inserter(outReductionTypes),
- [](mlir::Value v) { return v.getType(); });
+ std::back_inserter(result.reductionDeclSymbols));
+
+ if (outReductionTypes) {
+ outReductionTypes->reserve(outReductionTypes->size() +
+ reductionVars.size());
+ llvm::transform(reductionVars, std::back_inserter(*outReductionTypes),
+ [](mlir::Value v) { return v.getType(); });
+ }
+
+ if (outReductionSyms)
+ llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms));
});
}
bool ClauseProcessor::processSectionsReduction(
- mlir::Location currentLocation) const {
+ mlir::Location currentLocation, mlir::omp::ReductionClauseOps &) const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) {
TODO(currentLocation, "OMPC_Reduction");
@@ -967,30 +962,30 @@ bool ClauseProcessor::processEnter(
}
bool ClauseProcessor::processUseDeviceAddr(
- llvm::SmallVectorImpl<mlir::Value> &operands,
+ mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
const {
return findRepeatableClause<omp::clause::UseDeviceAddr>(
[&](const omp::clause::UseDeviceAddr &clause,
const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
- useDeviceLocs, useDeviceSymbols);
+ addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms);
});
}
bool ClauseProcessor::processUseDevicePtr(
- llvm::SmallVectorImpl<mlir::Value> &operands,
+ mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
const {
return findRepeatableClause<omp::clause::UseDevicePtr>(
[&](const omp::clause::UseDevicePtr &clause,
const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
- useDeviceLocs, useDeviceSymbols);
+ addUseDeviceClause(converter, clause.v, result.useDevicePtrVars,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms);
});
}
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index c0c603feb296af..d933e0a913d2bc 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -37,7 +37,7 @@ namespace omp {
/// corresponding clause if it is present in the clause list. Otherwise, they
/// will return `false` to signal that the clause was not found.
///
-/// The intended use is of this class is to move clause processing outside of
+/// The intended use of this class is to move clause processing outside of
/// construct processing, since the same clauses can appear attached to
/// different constructs and constructs can be combined, so that code
/// duplication is minimized.
@@ -56,94 +56,83 @@ class ClauseProcessor {
// 'Unique' clauses: They can appear at most once in the clause list.
bool processCollapse(
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const;
bool processDefault() const;
bool processDevice(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
+ mlir::omp::DeviceClauseOps &result) const;
+ bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processFinal(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processHint(mlir::IntegerAttr &result) const;
- bool processMergeable(mlir::UnitAttr &result) const;
- bool processNowait(mlir::UnitAttr &result) const;
+ mlir::omp::FinalClauseOps &result) const;
+ bool processHint(mlir::omp::HintClauseOps &result) const;
+ bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
+ bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
+ mlir::omp::NumTeamsClauseOps &result) const;
bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processOrdered(mlir::IntegerAttr &result) const;
+ mlir::omp::NumThreadsClauseOps &result) const;
+ bool processOrdered(mlir::omp::OrderedClauseOps &result) const;
bool processPriority(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const;
- bool processSafelen(mlir::IntegerAttr &result) const;
- bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr,
- mlir::omp::ScheduleModifierAttr &modifierAttr,
- mlir::UnitAttr &simdModifierAttr) const;
- bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processSimdlen(mlir::IntegerAttr &result) const;
+ mlir::omp::PriorityClauseOps &result) const;
+ bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
+ bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
+ bool processSchedule(Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ScheduleClauseOps &result) const;
+ bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const;
bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processUntied(mlir::UnitAttr &result) const;
+ mlir::omp::ThreadLimitClauseOps &result) const;
+ bool processUntied(mlir::omp::UntiedClauseOps &result) const;
// 'Repeatable' clauses: They can appear multiple times in the clause list.
- bool
- processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
- llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const;
+ bool processAllocate(mlir::omp::AllocateClauseOps &result) const;
bool processCopyin() const;
- bool processCopyPrivate(
- mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
- llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const;
- bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
- llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
+ bool processCopyprivate(mlir::Location currentLocation,
+ mlir::omp::CopyprivateClauseOps &result) const;
+ bool processDepend(mlir::omp::DependClauseOps &result) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
- mlir::Value &result) const;
+ mlir::omp::IfClauseOps &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
// This method is used to process a map clause.
- // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to
+ // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to
// store the original type, location and Fortran symbol for the map operands.
// They may be used later on to create the block_arguments for some of the
// target directives that require it.
- bool processMap(mlir::Location currentLocation,
- const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
- llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *mapSymbols = nullptr) const;
- bool
- processReduction(mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *reductionSymbols = nullptr) const;
- bool processSectionsReduction(mlir::Location currentLocation) const;
+ bool processMap(
+ mlir::Location currentLocation, const llvm::omp::Directive &directive,
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::MapClauseOps &result,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms =
+ nullptr,
+ llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
+ llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
+ bool processReduction(
+ mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
+ llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSyms =
+ nullptr) const;
+ bool processSectionsReduction(mlir::Location currentLocation,
+ mlir::omp::ReductionClauseOps &result) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
- processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
+ processUseDeviceAddr(mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) const;
+ &useDeviceSyms) const;
bool
- processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
+ processUseDevicePtr(mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) const;
+ &useDeviceSyms) const;
template <typename T>
bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands);
+ mlir::omp::MapClauseOps &result);
// Call this method for these clauses that should be supported but are not
// implemented yet. It triggers a compilation error if any of the given
@@ -185,7 +174,7 @@ class ClauseProcessor {
template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
+ mlir::omp::MapClauseOps &result) {
return findRepeatableClause<T>(
[&](const T &clause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
@@ -227,7 +216,7 @@ bool ClauseProcessor::processMotionClauses(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
- mapOperands.push_back(mapOp);
+ result.mapVars.push_back(mapOp);
}
});
}
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index e114ab9f4548ab..5a42e6a6aa4175 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -23,11 +23,13 @@ namespace Fortran {
namespace lower {
namespace omp {
-void DataSharingProcessor::processStep1() {
+void DataSharingProcessor::processStep1(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
collectSymbolsForPrivatization();
collectDefaultSymbols();
- privatize();
- defaultPrivatize();
+ privatize(clauseOps, privateSyms);
+ defaultPrivatize(clauseOps, privateSyms);
insertBarrier();
}
@@ -299,14 +301,16 @@ void DataSharingProcessor::collectDefaultSymbols() {
}
}
-void DataSharingProcessor::privatize() {
+void DataSharingProcessor::privatize(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
if (const auto *commonDet =
sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
for (const auto &mem : commonDet->objects())
- doPrivatize(&*mem);
+ doPrivatize(&*mem, clauseOps, privateSyms);
} else
- doPrivatize(sym);
+ doPrivatize(sym, clauseOps, privateSyms);
}
}
@@ -323,7 +327,9 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
}
}
-void DataSharingProcessor::defaultPrivatize() {
+void DataSharingProcessor::defaultPrivatize(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
for (const Fortran::semantics::Symbol *sym : defaultSymbols) {
if (!Fortran::semantics::IsProcedure(*sym) &&
!sym->GetUltimate().has<Fortran::semantics::DerivedTypeDetails>() &&
@@ -331,11 +337,14 @@ void DataSharingProcessor::defaultPrivatize() {
!symbolsInNestedRegions.contains(sym) &&
!symbolsInParentRegions.contains(sym) &&
!privatizedSymbols.contains(sym))
- doPrivatize(sym);
+ doPrivatize(sym, clauseOps, privateSyms);
}
}
-void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) {
+void DataSharingProcessor::doPrivatize(
+ const Fortran::semantics::Symbol *sym,
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
if (!useDelayedPrivatization) {
cloneSymbol(sym);
copyFirstPrivateSymbol(sym);
@@ -419,10 +428,13 @@ void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) {
return result;
}();
- delayedPrivatizationInfo.privatizers.push_back(
- mlir::SymbolRefAttr::get(privatizerOp));
- delayedPrivatizationInfo.originalAddresses.push_back(hsb.getAddr());
- delayedPrivatizationInfo.symbols.push_back(sym);
+ if (clauseOps) {
+ clauseOps->privatizers.push_back(mlir::SymbolRefAttr::get(privatizerOp));
+ clauseOps->privateVars.push_back(hsb.getAddr());
+ }
+
+ if (privateSyms)
+ privateSyms->push_back(sym);
}
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
index 226abe96705e35..9724b3d5ed02fe 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -19,28 +19,17 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
+namespace mlir {
+namespace omp {
+struct PrivateClauseOps;
+} // namespace omp
+} // namespace mlir
+
namespace Fortran {
namespace lower {
namespace omp {
class DataSharingProcessor {
-public:
- /// Collects all the information needed for delayed privatization. This can be
- /// used by ops with data-sharing clauses to properly generate their regions
- /// (e.g. add region arguments) and map the original SSA values to their
- /// corresponding OMP region operands.
- struct DelayedPrivatizationInfo {
- // The list of symbols referring to delayed privatizer ops (i.e.
- // `omp.private` ops).
- llvm::SmallVector<mlir::SymbolRefAttr> privatizers;
- // SSA values that correspond to "original" values being privatized.
- // "Original" here means the SSA value outside the OpenMP region from which
- // a clone is created inside the region.
- llvm::SmallVector<mlir::Value> originalAddresses;
- // Fortran symbols corresponding to the above SSA values.
- llvm::SmallVector<const Fortran::semantics::Symbol *> symbols;
- };
-
private:
bool hasLastPrivateOp;
mlir::OpBuilder::InsertPoint lastPrivIP;
@@ -57,7 +46,6 @@ class DataSharingProcessor {
Fortran::lower::pft::Evaluation &eval;
bool useDelayedPrivatization;
Fortran::lower::SymMap *symTable;
- DelayedPrivatizationInfo delayedPrivatizationInfo;
bool needBarrier();
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
@@ -67,9 +55,16 @@ class DataSharingProcessor {
void collectSymbolsForPrivatization();
void insertBarrier();
void collectDefaultSymbols();
- void privatize();
- void defaultPrivatize();
- void doPrivatize(const Fortran::semantics::Symbol *sym);
+ void privatize(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
+ void defaultPrivatize(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
+ void doPrivatize(
+ const Fortran::semantics::Symbol *sym,
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
void copyLastPrivatize(mlir::Operation *op);
void insertLastPrivateCompare(mlir::Operation *op);
void cloneSymbol(const Fortran::semantics::Symbol *sym);
@@ -103,17 +98,15 @@ class DataSharingProcessor {
// Step2 performs the copying for lastprivates and requires knowledge of the
// MLIR operation to insert the last private update. Step2 adds
// dealocation code as well.
- void processStep1();
+ void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ *privateSyms = nullptr);
void processStep2(mlir::Operation *op, bool isLoop);
void setLoopIV(mlir::Value iv) {
assert(!loopIV && "Loop iteration variable already set");
loopIV = iv;
}
-
- const DelayedPrivatizationInfo &getDelayedPrivatizationInfo() const {
- return delayedPrivatizationInfo;
- }
};
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 0cf2a8f97040a8..d67060d1cce72b 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -523,19 +523,25 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation) {
return genOpWithBody<mlir::omp::MasterOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested),
- /*resultTypes=*/mlir::TypeRange());
+ .setGenNested(genNested));
}
static mlir::omp::OrderedRegionOp
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation) {
+ mlir::Location currentLocation,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::omp::OrderedRegionClauseOps clauseOps;
+
+ ClauseProcessor cp(converter, semaCtx, clauseList);
+ cp.processTODO<clause::Simd>(currentLocation,
+ llvm::omp::Directive::OMPD_ordered);
+
return genOpWithBody<mlir::omp::OrderedRegionOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested),
- /*simd=*/false);
+ clauseOps);
}
static mlir::omp::ParallelOp
@@ -546,77 +552,62 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList,
bool outerCombined = false) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, numThreadsClauseOperand;
- mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
- reductionVars;
+ mlir::omp::ParallelClauseOps clauseOps;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_parallel, ifClauseOperand);
- cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
- cp.processProcBind(procBindKindAttr);
+ cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
+ cp.processNumThreads(stmtCtx, clauseOps);
+ cp.processProcBind(clauseOps);
cp.processDefault();
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processAllocate(clauseOps);
+
if (!outerCombined)
- cp.processReduction(currentLocation, reductionVars, reductionTypes,
- reductionDeclSymbols, &reductionSymbols);
+ cp.processReduction(currentLocation, clauseOps, &reductionTypes,
+ &reductionSyms);
+
+ if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
+ clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
auto reductionCallback = [&](mlir::Operation *op) {
- llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
+ llvm::SmallVector<mlir::Location> locs(clauseOps.reductionVars.size(),
currentLocation);
- auto *block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
- reductionTypes, locs);
+ auto *block =
+ firOpBuilder.createBlock(&op->getRegion(0), {}, reductionTypes, locs);
for (auto [arg, prv] :
- llvm::zip_equal(reductionSymbols, block->getArguments())) {
+ llvm::zip_equal(reductionSyms, block->getArguments())) {
converter.bindSymbol(*arg, prv);
}
- return reductionSymbols;
+ return reductionSyms;
};
- mlir::UnitAttr byrefAttr;
- if (ReductionProcessor::doReductionByRef(reductionVars))
- byrefAttr = converter.getFirOpBuilder().getUnitAttr();
-
OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauseList)
- .setReductions(&reductionSymbols, &reductionTypes)
+ .setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(reductionCallback);
- if (!enableDelayedPrivatization) {
- return genOpWithBody<mlir::omp::ParallelOp>(
- genInfo,
- /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
- numThreadsClauseOperand, allocateOperands, allocatorOperands,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- reductionDeclSymbols),
- procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
- /*privatizers=*/nullptr, byrefAttr);
- }
+ if (!enableDelayedPrivatization)
+ return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
bool privatize = !outerCombined;
DataSharingProcessor dsp(converter, semaCtx, clauseList, eval,
/*useDelayedPrivatization=*/true, &symTable);
if (privatize)
- dsp.processStep1();
-
- const auto &delayedPrivatizationInfo = dsp.getDelayedPrivatizationInfo();
+ dsp.processStep1(&clauseOps, &privateSyms);
auto genRegionEntryCB = [&](mlir::Operation *op) {
auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
- llvm::SmallVector<mlir::Location> reductionLocs(reductionVars.size(),
- currentLocation);
+ llvm::SmallVector<mlir::Location> reductionLocs(
+ clauseOps.reductionVars.size(), currentLocation);
mlir::OperandRange privateVars = parallelOp.getPrivateVars();
mlir::Region &region = parallelOp.getRegion();
@@ -631,12 +622,12 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::transform(privateVars, std::back_inserter(privateVarLocs),
[](mlir::Value v) { return v.getLoc(); });
- converter.getFirOpBuilder().createBlock(&region, /*insertPt=*/{},
- privateVarTypes, privateVarLocs);
+ firOpBuilder.createBlock(&region, /*insertPt=*/{}, privateVarTypes,
+ privateVarLocs);
llvm::SmallVector<const Fortran::semantics::Symbol *> allSymbols =
- reductionSymbols;
- allSymbols.append(delayedPrivatizationInfo.symbols);
+ reductionSyms;
+ allSymbols.append(privateSyms);
for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) {
converter.bindSymbol(*arg, prv);
}
@@ -646,26 +637,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
// TODO Merge with the reduction CB.
genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
-
- llvm::SmallVector<mlir::Attribute> privatizers(
- delayedPrivatizationInfo.privatizers.begin(),
- delayedPrivatizationInfo.privatizers.end());
-
- return genOpWithBody<mlir::omp::ParallelOp>(
- genInfo,
- /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
- numThreadsClauseOperand, allocateOperands, allocatorOperands,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- reductionDeclSymbols),
- procBindKindAttr, delayedPrivatizationInfo.originalAddresses,
- delayedPrivatizationInfo.privatizers.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- privatizers),
- byrefAttr);
+ return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
}
static mlir::omp::SectionOp
@@ -689,28 +661,21 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &beginClauseList,
const Fortran::parser::OmpClauseList &endClauseList) {
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
- llvm::SmallVector<mlir::Value> copyPrivateVars;
- llvm::SmallVector<mlir::Attribute> copyPrivateFuncs;
- mlir::UnitAttr nowaitAttr;
+ mlir::omp::SingleClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, beginClauseList);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processAllocate(clauseOps);
+ // TODO Support delayed privatization.
ClauseProcessor ecp(converter, semaCtx, endClauseList);
- ecp.processNowait(nowaitAttr);
- ecp.processCopyPrivate(currentLocation, copyPrivateVars, copyPrivateFuncs);
+ ecp.processNowait(clauseOps);
+ ecp.processCopyprivate(currentLocation, clauseOps);
return genOpWithBody<mlir::omp::SingleOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&beginClauseList),
- allocateOperands, allocatorOperands, copyPrivateVars,
- copyPrivateFuncs.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- copyPrivateFuncs),
- nowaitAttr);
+ clauseOps);
}
static mlir::omp::TaskOp
@@ -720,21 +685,19 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand;
- mlir::UnitAttr untiedAttr, mergeableAttr;
- llvm::SmallVector<mlir::Attribute> dependTypeOperands;
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
- dependOperands;
+ mlir::omp::TaskClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_task, ifClauseOperand);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
+ cp.processAllocate(clauseOps);
cp.processDefault();
- cp.processFinal(stmtCtx, finalClauseOperand);
- cp.processUntied(untiedAttr);
- cp.processMergeable(mergeableAttr);
- cp.processPriority(stmtCtx, priorityClauseOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
+ cp.processFinal(stmtCtx, clauseOps);
+ cp.processUntied(clauseOps);
+ cp.processMergeable(clauseOps);
+ cp.processPriority(stmtCtx, clauseOps);
+ cp.processDepend(clauseOps);
+ // TODO Support delayed privatization.
+
cp.processTODO<clause::InReduction, clause::Detach, clause::Affinity>(
currentLocation, llvm::omp::Directive::OMPD_task);
@@ -742,14 +705,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&clauseList),
- ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr,
- /*in_reduction_vars=*/mlir::ValueRange(),
- /*in_reductions=*/nullptr, priorityClauseOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, allocateOperands, allocatorOperands);
+ clauseOps);
}
static mlir::omp::TaskgroupOp
@@ -758,17 +714,18 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval, bool genNested,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
+ mlir::omp::TaskgroupClauseOps clauseOps;
+
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processAllocate(clauseOps);
cp.processTODO<clause::TaskReduction>(currentLocation,
llvm::omp::Directive::OMPD_taskgroup);
+
return genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&clauseList),
- /*task_reduction_vars=*/mlir::ValueRange(),
- /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
+ clauseOps);
}
// This helper function implements the functionality of "promoting"
@@ -789,8 +746,7 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
// clause. Support for such list items in a use_device_ptr clause
// is deprecated."
static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
- llvm::SmallVectorImpl<mlir::Value> &devicePtrOperands,
- llvm::SmallVectorImpl<mlir::Value> &deviceAddrOperands,
+ mlir::omp::UseDeviceClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -803,9 +759,10 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
// Iterate over our use_device_ptr list and shift all non-cptr arguments into
// use_device_addr.
- for (auto *it = devicePtrOperands.begin(); it != devicePtrOperands.end();) {
+ for (auto *it = clauseOps.useDevicePtrVars.begin();
+ it != clauseOps.useDevicePtrVars.end();) {
if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) {
- deviceAddrOperands.push_back(*it);
+ clauseOps.useDeviceAddrVars.push_back(*it);
// We have to shuffle the symbols around as well, to maintain
// the correct Input -> BlockArg for use_device_ptr/use_device_addr.
// NOTE: However, as map's do not seem to be included currently
@@ -813,11 +770,11 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
// future alterations. I believe the reason they are not currently
// is that the BlockArg assign/lowering needs to be extended
// to a greater set of types.
- auto idx = std::distance(devicePtrOperands.begin(), it);
+ auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it);
moveElementToBack(idx, useDeviceTypes);
moveElementToBack(idx, useDeviceLocs);
moveElementToBack(idx, useDeviceSymbols);
- it = devicePtrOperands.erase(it);
+ it = clauseOps.useDevicePtrVars.erase(it);
continue;
}
++it;
@@ -831,20 +788,19 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, deviceOperand;
- llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
- deviceAddrOperands;
+ mlir::omp::TargetDataClauseOps clauseOps;
llvm::SmallVector<mlir::Type> useDeviceTypes;
llvm::SmallVector<mlir::Location> useDeviceLocs;
- llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_target_data, ifClauseOperand);
- cp.processDevice(stmtCtx, deviceOperand);
- cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
- useDeviceSymbols);
- cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
- useDeviceSymbols);
+ cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs,
+ useDeviceSyms);
+ cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs,
+ useDeviceSyms);
+
// This function implements the deprecated functionality of use_device_ptr
// that allows users to provide non-CPTR arguments to it with the caveat
// that the compiler will treat them as use_device_addr. A lot of legacy
@@ -856,17 +812,16 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
// ordering.
// TODO: Perhaps create a user provideable compiler option that will
// re-introduce a hard-error rather than a warning in these cases.
- promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
- devicePtrOperands, deviceAddrOperands, useDeviceTypes, useDeviceLocs,
- useDeviceSymbols);
+ promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes,
+ useDeviceLocs, useDeviceSyms);
cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data,
- stmtCtx, mapOperands);
+ stmtCtx, clauseOps);
auto dataOp = converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>(
- currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
- deviceAddrOperands, mapOperands);
+ currentLocation, clauseOps);
+
genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, dataOp,
- useDeviceTypes, useDeviceLocs, useDeviceSymbols,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms,
currentLocation);
return dataOp;
}
@@ -879,10 +834,7 @@ static OpTy genTargetEnterExitDataUpdateOp(
const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, deviceOperand;
- mlir::UnitAttr nowaitAttr;
- llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
- llvm::SmallVector<mlir::Attribute> dependTypeOperands;
+ mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps;
// GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
[[maybe_unused]] llvm::omp::Directive directive;
@@ -897,25 +849,19 @@ static OpTy genTargetEnterExitDataUpdateOp(
}
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(directive, ifClauseOperand);
- cp.processDevice(stmtCtx, deviceOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
- cp.processNowait(nowaitAttr);
+ cp.processIf(directive, clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processDepend(clauseOps);
+ cp.processNowait(clauseOps);
if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) {
- cp.processMotionClauses<clause::To>(stmtCtx, mapOperands);
- cp.processMotionClauses<clause::From>(stmtCtx, mapOperands);
+ cp.processMotionClauses<clause::To>(stmtCtx, clauseOps);
+ cp.processMotionClauses<clause::From>(stmtCtx, clauseOps);
} else {
- cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
+ cp.processMap(currentLocation, directive, stmtCtx, clauseOps);
}
- return firOpBuilder.create<OpTy>(
- currentLocation, ifClauseOperand, deviceOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, nowaitAttr, mapOperands);
+ return firOpBuilder.create<OpTy>(currentLocation, clauseOps);
}
// This functions creates a block for the body of the targetOp's region. It adds
@@ -925,9 +871,9 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
mlir::omp::TargetOp &targetOp,
- llvm::ArrayRef<mlir::Type> mapSymTypes,
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms,
llvm::ArrayRef<mlir::Location> mapSymLocs,
- llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSymbols,
+ llvm::ArrayRef<mlir::Type> mapSymTypes,
const mlir::Location &currentLocation) {
assert(mapSymTypes.size() == mapSymLocs.size());
@@ -956,7 +902,7 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
};
// Bind the symbols to their corresponding block arguments.
- for (auto [argIndex, argSymbol] : llvm::enumerate(mapSymbols)) {
+ for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) {
const mlir::BlockArgument &arg = region.getArgument(argIndex);
// Avoid capture of a reference to a structured binding.
const Fortran::semantics::Symbol *sym = argSymbol;
@@ -1080,22 +1026,20 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &clauseList,
llvm::omp::Directive directive, bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
- mlir::UnitAttr nowaitAttr;
- llvm::SmallVector<mlir::Attribute> dependTypeOperands;
- llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
- llvm::SmallVector<mlir::Type> mapSymTypes;
+ mlir::omp::TargetClauseOps clauseOps;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms;
llvm::SmallVector<mlir::Location> mapSymLocs;
- llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
+ llvm::SmallVector<mlir::Type> mapSymTypes;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_target, ifClauseOperand);
- cp.processDevice(stmtCtx, deviceOperand);
- cp.processThreadLimit(stmtCtx, threadLimitOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
- cp.processNowait(nowaitAttr);
- cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
- &mapSymLocs, &mapSymbols);
+ cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processThreadLimit(stmtCtx, clauseOps);
+ cp.processDepend(clauseOps);
+ cp.processNowait(clauseOps);
+ cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms,
+ &mapSymLocs, &mapSymTypes);
+ // TODO Support delayed privatization.
cp.processTODO<clause::Private, clause::Firstprivate, clause::IsDevicePtr,
clause::HasDeviceAddr, clause::Reduction, clause::InReduction,
@@ -1107,7 +1051,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
// symbols used inside the region that have not been explicitly mapped using
// the map clause.
auto captureImplicitMap = [&](const Fortran::semantics::Symbol &sym) {
- if (llvm::find(mapSymbols, &sym) == mapSymbols.end()) {
+ if (llvm::find(mapSyms, &sym) == mapSyms.end()) {
mlir::Value baseOp = converter.getSymbolAddress(sym);
if (!baseOp)
if (const auto *details = sym.template detailsIf<
@@ -1178,25 +1122,20 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
mapFlag),
captureKind, baseOp.getType());
- mapOperands.push_back(mapOp);
- mapSymTypes.push_back(baseOp.getType());
+ clauseOps.mapVars.push_back(mapOp);
+ mapSyms.push_back(&sym);
mapSymLocs.push_back(baseOp.getLoc());
- mapSymbols.push_back(&sym);
+ mapSymTypes.push_back(baseOp.getType());
}
}
};
Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
- currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, nowaitAttr, mapOperands);
+ currentLocation, clauseOps);
- genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
- mapSymLocs, mapSymbols, currentLocation);
+ genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms,
+ mapSymLocs, mapSymTypes, currentLocation);
return targetOp;
}
@@ -1209,17 +1148,16 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &clauseList,
bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand;
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
- reductionVars;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
+ mlir::omp::TeamsClauseOps clauseOps;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_teams, ifClauseOperand);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
+ cp.processAllocate(clauseOps);
cp.processDefault();
- cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
- cp.processThreadLimit(stmtCtx, threadLimitClauseOperand);
+ cp.processNumTeams(stmtCtx, clauseOps);
+ cp.processThreadLimit(stmtCtx, clauseOps);
+ // TODO Support delayed privatization.
+
cp.processTODO<clause::Reduction>(currentLocation,
llvm::omp::Directive::OMPD_teams);
@@ -1228,30 +1166,20 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
.setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauseList),
- /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
- threadLimitClauseOperand, allocateOperands, allocatorOperands,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- reductionDeclSymbols));
+ clauseOps);
}
/// Extract the list of function and variable symbols affected by the given
/// 'declare target' directive and return the intended device type for them.
-static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
+static void getDeclareTargetInfo(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+ mlir::omp::DeclareTargetClauseOps &clauseOps,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
-
- // The default capture type
- mlir::omp::DeclareTargetDeviceType deviceType =
- mlir::omp::DeclareTargetDeviceType::any;
const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
declareTargetConstruct.t);
-
if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
ObjectList objects{makeList(*objectList, semaCtx)};
@@ -1272,12 +1200,10 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
cp.processTo(symbolAndClause);
cp.processEnter(symbolAndClause);
cp.processLink(symbolAndClause);
- cp.processDeviceType(deviceType);
+ cp.processDeviceType(clauseOps);
cp.processTODO<clause::Indirect>(converter.getCurrentLocation(),
llvm::omp::Directive::OMPD_declare_target);
}
-
- return deviceType;
}
static void collectDeferredDeclareTargets(
@@ -1287,9 +1213,10 @@ static void collectDeferredDeclareTargets(
const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
llvm::SmallVectorImpl<Fortran::lower::OMPDeferredDeclareTargetInfo>
&deferredDeclareTarget) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
- mlir::omp::DeclareTargetDeviceType devType = getDeclareTargetInfo(
- converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
@@ -1299,8 +1226,9 @@ static void collectDeferredDeclareTargets(
std::get<const Fortran::semantics::Symbol &>(symClause)));
if (!op) {
- deferredDeclareTarget.push_back(
- {std::get<0>(symClause), devType, std::get<1>(symClause)});
+ deferredDeclareTarget.push_back({std::get<0>(symClause),
+ clauseOps.deviceType,
+ std::get<1>(symClause)});
}
}
}
@@ -1312,9 +1240,10 @@ getDeclareTargetFunctionDevice(
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
- mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
- converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
@@ -1324,7 +1253,7 @@ getDeclareTargetFunctionDevice(
std::get<const Fortran::semantics::Symbol &>(symClause)));
if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
- return deviceType;
+ return clauseOps.deviceType;
}
return std::nullopt;
@@ -1354,12 +1283,14 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_barrier:
firOpBuilder.create<mlir::omp::BarrierOp>(currentLocation);
break;
- case llvm::omp::Directive::OMPD_taskwait:
- ClauseProcessor(converter, semaCtx, opClauseList)
- .processTODO<clause::Depend, clause::Nowait>(
- currentLocation, llvm::omp::Directive::OMPD_taskwait);
- firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation);
+ case llvm::omp::Directive::OMPD_taskwait: {
+ mlir::omp::TaskwaitClauseOps clauseOps;
+ ClauseProcessor cp(converter, semaCtx, opClauseList);
+ cp.processTODO<clause::Depend, clause::Nowait>(
+ currentLocation, llvm::omp::Directive::OMPD_taskwait);
+ firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation, clauseOps);
break;
+ }
case llvm::omp::Directive::OMPD_taskyield:
firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation);
break;
@@ -1494,32 +1425,21 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value scheduleChunkClauseOperand, ifClauseOperand;
- llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
- llvm::SmallVector<mlir::Value> alignedVars, nontemporalVars;
+ mlir::omp::SimdLoopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
- llvm::SmallVector<mlir::Type> reductionTypes;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- mlir::omp::ClauseOrderKindAttr orderClauseOperand;
- mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand;
ClauseProcessor cp(converter, semaCtx, loopOpClauseList);
- cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
- cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
- cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols);
- cp.processIf(llvm::omp::Directive::OMPD_simd, ifClauseOperand);
- cp.processSimdlen(simdlenClauseOperand);
- cp.processSafelen(safelenClauseOperand);
+ cp.processCollapse(loc, eval, clauseOps, iv);
+ cp.processReduction(loc, clauseOps);
+ cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps);
+ cp.processSimdlen(clauseOps);
+ cp.processSafelen(clauseOps);
+ clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
+ // TODO Support delayed privatization.
+
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
clause::Nontemporal, clause::Order>(loc, ompDirective);
- mlir::TypeRange resultType;
- auto simdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
- loc, resultType, lowerBound, upperBound, step, alignedVars,
- /*alignment_values=*/nullptr, ifClauseOperand, nontemporalVars,
- orderClauseOperand, simdlenClauseOperand, safelenClauseOperand,
- /*inclusive=*/firOpBuilder.getUnitAttr());
-
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(loopOpClauseList));
@@ -1527,11 +1447,12 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
return genLoopVars(op, converter, loc, iv);
};
- createBodyOfOp<mlir::omp::SimdLoopOp>(
- simdLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&loopOpClauseList)
- .setDataSharingProcessor(&dsp)
- .setGenRegionEntryCb(ivCallback));
+ genOpWithBody<mlir::omp::SimdLoopOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
+ .setClauses(&loopOpClauseList)
+ .setDataSharingProcessor(&dsp)
+ .setGenRegionEntryCb(ivCallback),
+ clauseOps);
}
static void createWsloop(Fortran::lower::AbstractConverter &converter,
@@ -1546,77 +1467,50 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value scheduleChunkClauseOperand;
- llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
- llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
+ mlir::omp::WsloopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Type> reductionTypes;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
- mlir::omp::ClauseOrderKindAttr orderClauseOperand;
- mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
- mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
- mlir::IntegerAttr orderedClauseOperand;
- mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ClauseProcessor cp(converter, semaCtx, beginClauseList);
- cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
- cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
- cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols,
- &reductionSymbols);
- cp.processTODO<clause::Linear, clause::Order>(loc, ompDirective);
-
- if (ReductionProcessor::doReductionByRef(reductionVars))
- byrefOperand = firOpBuilder.getUnitAttr();
-
- auto wsLoopOp = firOpBuilder.create<mlir::omp::WsloopOp>(
- loc, lowerBound, upperBound, step, linearVars, linearStepVars,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(firOpBuilder.getContext(),
- reductionDeclSymbols),
- scheduleValClauseOperand, scheduleChunkClauseOperand,
- /*schedule_modifiers=*/nullptr,
- /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
- orderedClauseOperand, orderClauseOperand,
- /*inclusive=*/firOpBuilder.getUnitAttr());
-
- // Handle attribute based clauses.
- if (cp.processOrdered(orderedClauseOperand))
- wsLoopOp.setOrderedValAttr(orderedClauseOperand);
-
- if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand,
- scheduleSimdClauseOperand)) {
- wsLoopOp.setScheduleValAttr(scheduleValClauseOperand);
- wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand);
- wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
- }
+ cp.processCollapse(loc, eval, clauseOps, iv);
+ cp.processSchedule(stmtCtx, clauseOps);
+ cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
+ cp.processOrdered(clauseOps);
+ clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
+ // TODO Support delayed privatization.
+
+ if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
+ clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
+
+ cp.processTODO<clause::Allocate, clause::Linear, clause::Order>(loc,
+ ompDirective);
+
// In FORTRAN `nowait` clause occur at the end of `omp do` directive.
// i.e
// !$omp do
// <...>
// !$omp end do nowait
if (endClauseList) {
- if (ClauseProcessor(converter, semaCtx, *endClauseList)
- .processNowait(nowaitClauseOperand))
- wsLoopOp.setNowaitAttr(nowaitClauseOperand);
+ ClauseProcessor ecp(converter, semaCtx, *endClauseList);
+ ecp.processNowait(clauseOps);
}
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(beginClauseList));
auto ivCallback = [&](mlir::Operation *op) {
- return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols,
+ return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms,
reductionTypes);
};
- createBodyOfOp<mlir::omp::WsloopOp>(
- wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&beginClauseList)
- .setDataSharingProcessor(&dsp)
- .setReductions(&reductionSymbols, &reductionTypes)
- .setGenRegionEntryCb(ivCallback));
+ genOpWithBody<mlir::omp::WsloopOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
+ .setClauses(&beginClauseList)
+ .setDataSharingProcessor(&dsp)
+ .setReductions(&reductionSyms, &reductionTypes)
+ .setGenRegionEntryCb(ivCallback),
+ clauseOps);
}
static void createSimdWsloop(
@@ -1704,10 +1598,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
- mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
- converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
mlir::Operation *op = mod.lookupSymbol(converter.mangleName(
@@ -1721,7 +1616,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
markDeclareTarget(
op, converter,
- std::get<mlir::omp::DeclareTargetCaptureClause>(symClause), deviceType);
+ std::get<mlir::omp::DeclareTargetCaptureClause>(symClause),
+ clauseOps.deviceType);
}
}
@@ -1853,7 +1749,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
!std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
+ !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u) &&
+ !std::get_if<Fortran::parser::OmpClause::Simd>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
}
}
@@ -1873,7 +1770,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
break;
case llvm::omp::Directive::OMPD_ordered:
genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true,
- currentLocation);
+ currentLocation, beginClauseList);
break;
case llvm::omp::Directive::OMPD_parallel:
genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
@@ -1964,7 +1861,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
- mlir::IntegerAttr hintClauseOp;
std::string name;
const Fortran::parser::OmpCriticalDirective &cd =
std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
@@ -1973,21 +1869,28 @@ genOMP(Fortran::lower::AbstractConverter &converter,
std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
}
- const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
- ClauseProcessor(converter, semaCtx, clauseList).processHint(hintClauseOp);
-
mlir::omp::CriticalOp criticalOp = [&]() {
if (name.empty()) {
return firOpBuilder.create<mlir::omp::CriticalOp>(
currentLocation, mlir::FlatSymbolRefAttr());
}
+
mlir::ModuleOp module = firOpBuilder.getModule();
mlir::OpBuilder modBuilder(module.getBodyRegion());
auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
- if (!global)
- global = modBuilder.create<mlir::omp::CriticalDeclareOp>(
- currentLocation,
- mlir::StringAttr::get(firOpBuilder.getContext(), name), hintClauseOp);
+ if (!global) {
+ mlir::omp::CriticalClauseOps clauseOps;
+ const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
+
+ ClauseProcessor cp(converter, semaCtx, clauseList);
+ cp.processHint(clauseOps);
+ clauseOps.nameAttr =
+ mlir::StringAttr::get(firOpBuilder.getContext(), name);
+
+ global = modBuilder.create<mlir::omp::CriticalDeclareOp>(currentLocation,
+ clauseOps);
+ }
+
return firOpBuilder.create<mlir::omp::CriticalOp>(
currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(),
global.getSymName()));
@@ -2104,8 +2007,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
mlir::Location currentLocation = converter.getCurrentLocation();
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
- mlir::UnitAttr nowaitClauseOperand;
+ mlir::omp::SectionsClauseOps clauseOps;
const auto &beginSectionsDirective =
std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t);
const auto &sectionsClauseList =
@@ -2114,8 +2016,9 @@ genOMP(Fortran::lower::AbstractConverter &converter,
// Process clauses before optional omp.parallel, so that new variables are
// allocated outside of the parallel region
ClauseProcessor cp(converter, semaCtx, sectionsClauseList);
- cp.processSectionsReduction(currentLocation);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processSectionsReduction(currentLocation, clauseOps);
+ cp.processAllocate(clauseOps);
+ // TODO Support delayed privatization.
llvm::omp::Directive dir =
std::get<Fortran::parser::OmpSectionsDirective>(beginSectionsDirective.t)
@@ -2132,16 +2035,14 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const auto &endSectionsClauseList =
std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
ClauseProcessor(converter, semaCtx, endSectionsClauseList)
- .processNowait(nowaitClauseOperand);
+ .processNowait(clauseOps);
}
// SECTIONS construct
genOpWithBody<mlir::omp::SectionsOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(false),
- /*reduction_vars=*/mlir::ValueRange(),
- /*reductions=*/nullptr, allocateOperands, allocatorOperands,
- nowaitClauseOperand);
+ clauseOps);
const auto &sectionBlocks =
std::get<Fortran::parser::OmpSectionBlocks>(sectionsConstruct.t);
>From e291fad68b78d28bfa73caab94ddcb978db2a602 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 28 Mar 2024 15:14:37 +0000
Subject: [PATCH 03/17] [Flang][OpenMP][Lower] Split MLIR codegen for clauses
and constructs
This patch performs several cleanups with the main purpose of normalizing the
code patterns used to trigger codegen for MLIR OpenMP operations and making the
processing of clauses and constructs independent. The following changes are
made:
- Clean up unused `directive` argument to `ClauseProcessor::processMap()`.
- Move general helper functions in OpenMP.cpp to the appropriate section of the
file.
- Create `gen<OpName>Clauses()` functions containing the clause processing code
specific for the associated OpenMP construct.
- Update `gen<OpName>Op()` functions to call the corresponding
`gen<OpName>Clauses()` function.
- Sort calls to `ClauseProcessor::process<ClauseName>()` alphabetically, to
avoid inadvertently relying on some arbitrary order. Update some tests that
broke due to the order change.
- Normalize `genOMP()` functions so they all delegate the generation of MLIR to
`gen<OpName>Op()` functions following the same pattern.
- Only process `nowait` clause on `TARGET` constructs if not compiling for the
target device.
A later patch can move the calls to `gen<OpName>Clauses()` out of
`gen<OpName>Op()` functions and pass completed clause structures instead, in
preparation for supporting composite constructs. That will make it possible to
reuse clause processing for a given leaf construct when appearing alone or in a
combined or composite construct, while controlling where the associated code is
produced.
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 4 +-
flang/lib/Lower/OpenMP/ClauseProcessor.h | 3 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 2090 +++++++++--------
flang/test/Lower/OpenMP/FIR/target.f90 | 2 +-
flang/test/Lower/OpenMP/target.f90 | 2 +-
.../use-device-ptr-to-use-device-addr.f90 | 4 +-
6 files changed, 1173 insertions(+), 932 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index ee1f6c2fbc7e89..e2b26b3025049f 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -804,8 +804,8 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
}
bool ClauseProcessor::processMap(
- mlir::Location currentLocation, const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result,
+ mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index d933e0a913d2bc..9e59d754280ef4 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -102,8 +102,7 @@ class ClauseProcessor {
// They may be used later on to create the block_arguments for some of the
// target directives that require it.
bool processMap(
- mlir::Location currentLocation, const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx,
+ mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms =
nullptr,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index d67060d1cce72b..b6de2079a973f5 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -237,6 +237,276 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
return storeOp;
}
+// This helper function implements the functionality of "promoting"
+// non-CPTR arguments of use_device_ptr to use_device_addr
+// arguments (automagic conversion of use_device_ptr ->
+// use_device_addr in these cases). The way we do so currently is
+// through the shuffling of operands from the devicePtrOperands to
+// deviceAddrOperands where necessary and re-organizing the types,
+// locations and symbols to maintain the correct ordering of ptr/addr
+// input -> BlockArg.
+//
+// This effectively implements some deprecated OpenMP functionality
+// that some legacy applications unfortunately depend on
+// (deprecated in specification version 5.2):
+//
+// "If a list item in a use_device_ptr clause is not of type C_PTR,
+// the behavior is as if the list item appeared in a use_device_addr
+// clause. Support for such list items in a use_device_ptr clause
+// is deprecated."
+static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
+ mlir::omp::UseDeviceClauseOps &clauseOps,
+ llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ &useDeviceSymbols) {
+ auto moveElementToBack = [](size_t idx, auto &vector) {
+ auto *iter = std::next(vector.begin(), idx);
+ vector.push_back(*iter);
+ vector.erase(iter);
+ };
+
+ // Iterate over our use_device_ptr list and shift all non-cptr arguments into
+ // use_device_addr.
+ for (auto *it = clauseOps.useDevicePtrVars.begin();
+ it != clauseOps.useDevicePtrVars.end();) {
+ if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) {
+ clauseOps.useDeviceAddrVars.push_back(*it);
+ // We have to shuffle the symbols around as well, to maintain
+ // the correct Input -> BlockArg for use_device_ptr/use_device_addr.
+ // NOTE: However, as map's do not seem to be included currently
+ // this isn't as pertinent, but we must try to maintain for
+ // future alterations. I believe the reason they are not currently
+ // is that the BlockArg assign/lowering needs to be extended
+ // to a greater set of types.
+ auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it);
+ moveElementToBack(idx, useDeviceTypes);
+ moveElementToBack(idx, useDeviceLocs);
+ moveElementToBack(idx, useDeviceSymbols);
+ it = clauseOps.useDevicePtrVars.erase(it);
+ continue;
+ }
+ ++it;
+ }
+}
+
+/// Extract the list of function and variable symbols affected by the given
+/// 'declare target' directive and store the intended device type in clauseOps.
+static void getDeclareTargetInfo(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+ mlir::omp::DeclareTargetClauseOps &clauseOps,
+ llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
+ const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
+ declareTargetConstruct.t);
+ if (const auto *objectList{
+ Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
+ ObjectList objects{makeList(*objectList, semaCtx)};
+ // Case: declare target(func, var1, var2)
+ gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
+ symbolAndClause);
+ } else if (const auto *clauseList{
+ Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
+ spec.u)}) {
+ if (clauseList->v.empty()) {
+ // Case: declare target, implicit capture of function
+ symbolAndClause.emplace_back(
+ mlir::omp::DeclareTargetCaptureClause::to,
+ eval.getOwningProcedure()->getSubprogramSymbol());
+ }
+
+ ClauseProcessor cp(converter, semaCtx, *clauseList);
+ cp.processDeviceType(clauseOps);
+ cp.processEnter(symbolAndClause);
+ cp.processLink(symbolAndClause);
+ cp.processTo(symbolAndClause);
+ cp.processTODO<clause::Indirect>(converter.getCurrentLocation(),
+ llvm::omp::Directive::OMPD_declare_target);
+ }
+}
+
+static void collectDeferredDeclareTargets(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+ llvm::SmallVectorImpl<Fortran::lower::OMPDeferredDeclareTargetInfo>
+ &deferredDeclareTarget) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
+ llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
+ // Return the device type only if at least one of the targets for the
+ // directive is a function or subroutine
+ mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+
+ for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
+ mlir::Operation *op = mod.lookupSymbol(converter.mangleName(
+ std::get<const Fortran::semantics::Symbol &>(symClause)));
+
+ if (!op) {
+ deferredDeclareTarget.push_back({std::get<0>(symClause),
+ clauseOps.deviceType,
+ std::get<1>(symClause)});
+ }
+ }
+}
+
+static std::optional<mlir::omp::DeclareTargetDeviceType>
+getDeclareTargetFunctionDevice(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPDeclareTargetConstruct
+ &declareTargetConstruct) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
+ llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
+
+ // Return the device type only if at least one of the targets for the
+ // directive is a function or subroutine
+ mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+ for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
+ mlir::Operation *op = mod.lookupSymbol(converter.mangleName(
+ std::get<const Fortran::semantics::Symbol &>(symClause)));
+
+ if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
+ return clauseOps.deviceType;
+ }
+
+ return std::nullopt;
+}
+
+static llvm::SmallVector<const Fortran::semantics::Symbol *>
+genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
+ mlir::Location &loc,
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> args) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ auto &region = op->getRegion(0);
+
+ std::size_t loopVarTypeSize = 0;
+ for (const Fortran::semantics::Symbol *arg : args)
+ loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
+ mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+ llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
+ llvm::SmallVector<mlir::Location> locs(args.size(), loc);
+ firOpBuilder.createBlock(&region, {}, tiv, locs);
+ // The argument is not currently in memory, so make a temporary for the
+ // argument, and store it there, then bind that location to the argument.
+ mlir::Operation *storeOp = nullptr;
+ for (auto [argIndex, argSymbol] : llvm::enumerate(args)) {
+ mlir::Value indexVal = fir::getBase(region.front().getArgument(argIndex));
+ storeOp =
+ createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
+ }
+ firOpBuilder.setInsertionPointAfter(storeOp);
+
+ return llvm::SmallVector<const Fortran::semantics::Symbol *>(args);
+}
+
+static void genReductionVars(
+ mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
+ mlir::Location &loc,
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
+ llvm::ArrayRef<mlir::Type> reductionTypes) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ llvm::SmallVector<mlir::Location> blockArgLocs(reductionArgs.size(), loc);
+
+ mlir::Block *entryBlock = firOpBuilder.createBlock(
+ &op->getRegion(0), {}, reductionTypes, blockArgLocs);
+
+ // Bind the reduction arguments to their block arguments.
+ for (auto [arg, prv] :
+ llvm::zip_equal(reductionArgs, entryBlock->getArguments())) {
+ converter.bindSymbol(*arg, prv);
+ }
+}
+
+static llvm::SmallVector<const Fortran::semantics::Symbol *>
+genLoopAndReductionVars(
+ mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
+ mlir::Location &loc,
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
+ llvm::ArrayRef<mlir::Type> reductionTypes) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+ llvm::SmallVector<mlir::Type> blockArgTypes;
+ llvm::SmallVector<mlir::Location> blockArgLocs;
+ blockArgTypes.reserve(loopArgs.size() + reductionArgs.size());
+ blockArgLocs.reserve(blockArgTypes.size());
+ mlir::Block *entryBlock;
+
+ if (loopArgs.size()) {
+ std::size_t loopVarTypeSize = 0;
+ for (const Fortran::semantics::Symbol *arg : loopArgs)
+ loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
+ mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+ std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(),
+ loopVarType);
+ std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc);
+ }
+ if (reductionArgs.size()) {
+ llvm::copy(reductionTypes, std::back_inserter(blockArgTypes));
+ std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc);
+ }
+ entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes,
+ blockArgLocs);
+ // The argument is not currently in memory, so make a temporary for the
+ // argument, and store it there, then bind that location to the argument.
+ if (loopArgs.size()) {
+ mlir::Operation *storeOp = nullptr;
+ for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
+ mlir::Value indexVal =
+ fir::getBase(op->getRegion(0).front().getArgument(argIndex));
+ storeOp =
+ createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
+ }
+ firOpBuilder.setInsertionPointAfter(storeOp);
+ }
+ // Bind the reduction arguments to their block arguments
+ for (auto [arg, prv] : llvm::zip_equal(
+ reductionArgs,
+ llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) {
+ converter.bindSymbol(*arg, prv);
+ }
+
+ return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs);
+}
+
+static void
+markDeclareTarget(mlir::Operation *op,
+ Fortran::lower::AbstractConverter &converter,
+ mlir::omp::DeclareTargetCaptureClause captureClause,
+ mlir::omp::DeclareTargetDeviceType deviceType) {
+ // TODO: Add support for program local variables with declare target applied
+ auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op);
+ if (!declareTargetOp)
+ fir::emitFatalError(
+ converter.getCurrentLocation(),
+ "Attempt to apply declare target on unsupported operation");
+
+ // The function or global already has a declare target applied to it, very
+ // likely through implicit capture (usage in another declare target
+ // function/subroutine). It should be marked as any if it has been assigned
+ // both host and nohost, else we skip, as there is no change
+ if (declareTargetOp.isDeclareTarget()) {
+ if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
+ declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
+ captureClause);
+ return;
+ }
+
+ declareTargetOp.setDeclareTarget(deviceType, captureClause);
+}
+
+//===----------------------------------------------------------------------===//
+// Op body generation helper structures and functions
+//===----------------------------------------------------------------------===//
+
struct OpWithBodyGenInfo {
/// A type for a code-gen callback function. This takes as argument the op for
/// which the code is being generated and returns the arguments of the op's
@@ -508,543 +778,726 @@ static void genBodyOfTargetDataOp(
genNestedEvaluations(converter, eval);
}
-template <typename OpTy, typename... Args>
-static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) {
- auto op = info.converter.getFirOpBuilder().create<OpTy>(
- info.loc, std::forward<Args>(args)...);
- createBodyOfOp<OpTy>(op, info);
- return op;
-}
-
-static mlir::omp::MasterOp
-genMasterOp(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation) {
- return genOpWithBody<mlir::omp::MasterOp>(
- OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested));
-}
-
-static mlir::omp::OrderedRegionOp
-genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation,
- const Fortran::parser::OmpClauseList &clauseList) {
- mlir::omp::OrderedRegionClauseOps clauseOps;
+// This functions creates a block for the body of the targetOp's region. It adds
+// all the symbols present in mapSymbols as block arguments to this block.
+static void
+genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::omp::TargetOp &targetOp,
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms,
+ llvm::ArrayRef<mlir::Location> mapSymLocs,
+ llvm::ArrayRef<mlir::Type> mapSymTypes,
+ const mlir::Location &currentLocation) {
+ assert(mapSymTypes.size() == mapSymLocs.size());
- ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processTODO<clause::Simd>(currentLocation,
- llvm::omp::Directive::OMPD_ordered);
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ mlir::Region &region = targetOp.getRegion();
- return genOpWithBody<mlir::omp::OrderedRegionOp>(
- OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested),
- clauseOps);
-}
+ auto *regionBlock =
+ firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
-static mlir::omp::ParallelOp
-genParallelOp(Fortran::lower::AbstractConverter &converter,
- Fortran::lower::SymMap &symTable,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation,
- const Fortran::parser::OmpClauseList &clauseList,
- bool outerCombined = false) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- Fortran::lower::StatementContext stmtCtx;
- mlir::omp::ParallelClauseOps clauseOps;
- llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms;
- llvm::SmallVector<mlir::Type> reductionTypes;
- llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
+ // Clones the `bounds` placing them inside the target region and returns them.
+ auto cloneBound = [&](mlir::Value bound) {
+ if (mlir::isMemoryEffectFree(bound.getDefiningOp())) {
+ mlir::Operation *clonedOp = bound.getDefiningOp()->clone();
+ regionBlock->push_back(clonedOp);
+ return clonedOp->getResult(0);
+ }
+ TODO(converter.getCurrentLocation(),
+ "target map clause operand unsupported bound type");
+ };
- ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
- cp.processNumThreads(stmtCtx, clauseOps);
- cp.processProcBind(clauseOps);
- cp.processDefault();
- cp.processAllocate(clauseOps);
+ auto cloneBounds = [cloneBound](llvm::ArrayRef<mlir::Value> bounds) {
+ llvm::SmallVector<mlir::Value> clonedBounds;
+ for (mlir::Value bound : bounds)
+ clonedBounds.emplace_back(cloneBound(bound));
+ return clonedBounds;
+ };
- if (!outerCombined)
- cp.processReduction(currentLocation, clauseOps, &reductionTypes,
- &reductionSyms);
+ // Bind the symbols to their corresponding block arguments.
+ for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) {
+ const mlir::BlockArgument &arg = region.getArgument(argIndex);
+ // Avoid capture of a reference to a structured binding.
+ const Fortran::semantics::Symbol *sym = argSymbol;
+ // Structure component symbols don't have bindings.
+ if (sym->owner().IsDerivedType())
+ continue;
+ fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
+ extVal.match(
+ [&](const fir::BoxValue &v) {
+ converter.bindSymbol(*sym,
+ fir::BoxValue(arg, cloneBounds(v.getLBounds()),
+ v.getExplicitParameters(),
+ v.getExplicitExtents()));
+ },
+ [&](const fir::MutableBoxValue &v) {
+ converter.bindSymbol(
+ *sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()),
+ v.getMutableProperties()));
+ },
+ [&](const fir::ArrayBoxValue &v) {
+ converter.bindSymbol(
+ *sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()),
+ cloneBounds(v.getLBounds()),
+ v.getSourceBox()));
+ },
+ [&](const fir::CharArrayBoxValue &v) {
+ converter.bindSymbol(
+ *sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()),
+ cloneBounds(v.getExtents()),
+ cloneBounds(v.getLBounds())));
+ },
+ [&](const fir::CharBoxValue &v) {
+ converter.bindSymbol(*sym,
+ fir::CharBoxValue(arg, cloneBound(v.getLen())));
+ },
+ [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); },
+ [&](const auto &) {
+ TODO(converter.getCurrentLocation(),
+ "target map clause operand unsupported type");
+ });
+ }
- if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
- clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
+ // Check if cloning the bounds introduced any dependency on the outer region.
+ // If so, then either clone them as well if they are MemoryEffectFree, or else
+ // copy them to a new temporary and add them to the map and block_argument
+ // lists and replace their uses with the new temporary.
+ llvm::SetVector<mlir::Value> valuesDefinedAbove;
+ mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
+ while (!valuesDefinedAbove.empty()) {
+ for (mlir::Value val : valuesDefinedAbove) {
+ mlir::Operation *valOp = val.getDefiningOp();
+ if (mlir::isMemoryEffectFree(valOp)) {
+ mlir::Operation *clonedOp = valOp->clone();
+ regionBlock->push_front(clonedOp);
+ val.replaceUsesWithIf(
+ clonedOp->getResult(0), [regionBlock](mlir::OpOperand &use) {
+ return use.getOwner()->getBlock() == regionBlock;
+ });
+ } else {
+ auto savedIP = firOpBuilder.getInsertionPoint();
+ firOpBuilder.setInsertionPointAfter(valOp);
+ auto copyVal =
+ firOpBuilder.createTemporary(val.getLoc(), val.getType());
+ firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal);
- auto reductionCallback = [&](mlir::Operation *op) {
- llvm::SmallVector<mlir::Location> locs(clauseOps.reductionVars.size(),
- currentLocation);
- auto *block =
- firOpBuilder.createBlock(&op->getRegion(0), {}, reductionTypes, locs);
- for (auto [arg, prv] :
- llvm::zip_equal(reductionSyms, block->getArguments())) {
- converter.bindSymbol(*arg, prv);
+ llvm::SmallVector<mlir::Value> bounds;
+ std::stringstream name;
+ firOpBuilder.setInsertionPoint(targetOp);
+ mlir::Value mapOp = createMapInfoOp(
+ firOpBuilder, copyVal.getLoc(), copyVal, mlir::Value{}, name.str(),
+ bounds, llvm::SmallVector<mlir::Value>{},
+ static_cast<
+ std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
+ mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType());
+ targetOp.getMapOperandsMutable().append(mapOp);
+ mlir::Value clonedValArg =
+ region.addArgument(copyVal.getType(), copyVal.getLoc());
+ firOpBuilder.setInsertionPointToStart(regionBlock);
+ auto loadOp = firOpBuilder.create<fir::LoadOp>(clonedValArg.getLoc(),
+ clonedValArg);
+ val.replaceUsesWithIf(
+ loadOp->getResult(0), [regionBlock](mlir::OpOperand &use) {
+ return use.getOwner()->getBlock() == regionBlock;
+ });
+ firOpBuilder.setInsertionPoint(regionBlock, savedIP);
+ }
}
- return reductionSyms;
- };
-
- OpWithBodyGenInfo genInfo =
- OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested)
- .setOuterCombined(outerCombined)
- .setClauses(&clauseList)
- .setReductions(&reductionSyms, &reductionTypes)
- .setGenRegionEntryCb(reductionCallback);
+ valuesDefinedAbove.clear();
+ mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
+ }
- if (!enableDelayedPrivatization)
- return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
+ // Insert dummy instruction to remember the insertion position. The
+ // marker will be deleted since there are not uses.
+ // In the HLFIR flow there are hlfir.declares inserted above while
+ // setting block arguments.
+ mlir::Value undefMarker = firOpBuilder.create<fir::UndefOp>(
+ targetOp.getOperation()->getLoc(), firOpBuilder.getIndexType());
- bool privatize = !outerCombined;
- DataSharingProcessor dsp(converter, semaCtx, clauseList, eval,
- /*useDelayedPrivatization=*/true, &symTable);
+ // Create blocks for unstructured regions. This has to be done since
+ // blocks are initially allocated with the function as the parent region.
+ if (eval.lowerAsUnstructured()) {
+ Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
+ mlir::omp::YieldOp>(
+ firOpBuilder, eval.getNestedEvaluations());
+ }
- if (privatize)
- dsp.processStep1(&clauseOps, &privateSyms);
+ firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
- auto genRegionEntryCB = [&](mlir::Operation *op) {
- auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
+ // Create the insertion point after the marker.
+ firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
+ if (genNested)
+ genNestedEvaluations(converter, eval);
+}
- llvm::SmallVector<mlir::Location> reductionLocs(
- clauseOps.reductionVars.size(), currentLocation);
+template <typename OpTy, typename... Args>
+static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) {
+ auto op = info.converter.getFirOpBuilder().create<OpTy>(
+ info.loc, std::forward<Args>(args)...);
+ createBodyOfOp<OpTy>(op, info);
+ return op;
+}
- mlir::OperandRange privateVars = parallelOp.getPrivateVars();
- mlir::Region &region = parallelOp.getRegion();
+//===----------------------------------------------------------------------===//
+// Code generation functions for clauses
+//===----------------------------------------------------------------------===//
- llvm::SmallVector<mlir::Type> privateVarTypes = reductionTypes;
- privateVarTypes.reserve(privateVarTypes.size() + privateVars.size());
- llvm::transform(privateVars, std::back_inserter(privateVarTypes),
- [](mlir::Value v) { return v.getType(); });
+static void genCriticalDeclareClauses(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
+ mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processHint(clauseOps);
+ clauseOps.nameAttr =
+ mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name);
+}
- llvm::SmallVector<mlir::Location> privateVarLocs = reductionLocs;
- privateVarLocs.reserve(privateVarLocs.size() + privateVars.size());
- llvm::transform(privateVars, std::back_inserter(privateVarLocs),
- [](mlir::Value v) { return v.getLoc(); });
+static void genFlushClauses(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const std::optional<Fortran::parser::OmpObjectList> &objects,
+ const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>
+ &clauses,
+ mlir::Location loc, llvm::SmallVectorImpl<mlir::Value> &operandRange) {
+ if (objects)
+ genObjectList2(*objects, converter, operandRange);
+
+ if (clauses && clauses->size() > 0)
+ TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause");
+}
- firOpBuilder.createBlock(&region, /*insertPt=*/{}, privateVarTypes,
- privateVarLocs);
+static void
+genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const Fortran::parser::OmpClauseList &clauses,
+ mlir::Location loc,
+ mlir::omp::OrderedRegionClauseOps &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processTODO<clause::Simd>(loc, llvm::omp::Directive::OMPD_ordered);
+}
- llvm::SmallVector<const Fortran::semantics::Symbol *> allSymbols =
- reductionSyms;
- allSymbols.append(privateSyms);
- for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) {
- converter.bindSymbol(*arg, prv);
- }
+static void genParallelClauses(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::StatementContext &stmtCtx,
+ const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
+ bool processReduction, mlir::omp::ParallelClauseOps &clauseOps,
+ llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processAllocate(clauseOps);
+ cp.processDefault();
+ cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
+ cp.processProcBind(clauseOps);
- return allSymbols;
- };
+ if (processReduction) {
+ cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
+ if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
+ clauseOps.reductionByRefAttr = converter.getFirOpBuilder().getUnitAttr();
+ }
- // TODO Merge with the reduction CB.
- genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
- return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
+ cp.processNumThreads(stmtCtx, clauseOps);
}
-static mlir::omp::SectionOp
-genSectionOp(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation,
- const Fortran::parser::OmpClauseList &sectionsClauseList) {
- // Currently only private/firstprivate clause is handled, and
- // all privatization is done within `omp.section` operations.
- return genOpWithBody<mlir::omp::SectionOp>(
- OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested)
- .setClauses(&sectionsClauseList));
+static void genSectionsClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const Fortran::parser::OmpClauseList &clauses,
+ mlir::Location loc,
+ bool clausesFromBeginSections,
+ mlir::omp::SectionsClauseOps &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ if (clausesFromBeginSections) {
+ cp.processAllocate(clauseOps);
+ cp.processSectionsReduction(loc, clauseOps);
+ // TODO Support delayed privatization.
+ } else {
+ cp.processNowait(clauseOps);
+ }
}
-static mlir::omp::SingleOp
-genSingleOp(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList &endClauseList) {
- mlir::omp::SingleClauseOps clauseOps;
+static void genSimdLoopClauses(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::StatementContext &stmtCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
+ mlir::omp::SimdLoopClauseOps &clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processCollapse(loc, eval, clauseOps, iv);
+ cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps);
+ cp.processReduction(loc, clauseOps);
+ cp.processSafelen(clauseOps);
+ cp.processSimdlen(clauseOps);
+ clauseOps.loopInclusiveAttr = converter.getFirOpBuilder().getUnitAttr();
+ // TODO Support delayed privatization.
- ClauseProcessor cp(converter, semaCtx, beginClauseList);
- cp.processAllocate(clauseOps);
+ cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
+ clause::Nontemporal, clause::Order>(
+ loc, llvm::omp::Directive::OMPD_simd);
+}
+
+static void genSingleClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const Fortran::parser::OmpClauseList &beginClauses,
+ const Fortran::parser::OmpClauseList &endClauses,
+ mlir::Location loc,
+ mlir::omp::SingleClauseOps &clauseOps) {
+ ClauseProcessor bcp(converter, semaCtx, beginClauses);
+ bcp.processAllocate(clauseOps);
// TODO Support delayed privatization.
- ClauseProcessor ecp(converter, semaCtx, endClauseList);
+ ClauseProcessor ecp(converter, semaCtx, endClauses);
+ ecp.processCopyprivate(loc, clauseOps);
ecp.processNowait(clauseOps);
- ecp.processCopyprivate(currentLocation, clauseOps);
+}
- return genOpWithBody<mlir::omp::SingleOp>(
- OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested)
- .setClauses(&beginClauseList),
- clauseOps);
+static void genTargetClauses(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::StatementContext &stmtCtx,
+ const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
+ bool processHostOnlyClauses, bool processReduction,
+ mlir::omp::TargetClauseOps &clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSyms,
+ llvm::SmallVectorImpl<mlir::Location> &mapSymLocs,
+ llvm::SmallVectorImpl<mlir::Type> &mapSymTypes) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processDepend(clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
+ cp.processMap(loc, stmtCtx, clauseOps, &mapSyms, &mapSymLocs, &mapSymTypes);
+ cp.processThreadLimit(stmtCtx, clauseOps);
+ // TODO Support delayed privatization.
+
+ if (processHostOnlyClauses)
+ cp.processNowait(clauseOps);
+
+ cp.processTODO<clause::Allocate, clause::Defaultmap, clause::Firstprivate,
+ clause::HasDeviceAddr, clause::InReduction,
+ clause::IsDevicePtr, clause::Private, clause::Reduction,
+ clause::UsesAllocators>(loc,
+ llvm::omp::Directive::OMPD_target);
}
-static mlir::omp::TaskOp
-genTaskOp(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation,
- const Fortran::parser::OmpClauseList &clauseList) {
- Fortran::lower::StatementContext stmtCtx;
- mlir::omp::TaskClauseOps clauseOps;
+static void genTargetDataClauses(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::StatementContext &stmtCtx,
+ const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
+ mlir::omp::TargetDataClauseOps &clauseOps,
+ llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps);
+ cp.processMap(loc, stmtCtx, clauseOps);
+ cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs,
+ useDeviceSyms);
+ cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs,
+ useDeviceSyms);
- ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
+ // This function implements the deprecated functionality of use_device_ptr
+ // that allows users to provide non-CPTR arguments to it with the caveat
+ // that the compiler will treat them as use_device_addr. A lot of legacy
+ // code may still depend on this functionality, so we should support it
+ // in some manner. We do so currently by simply shifting non-cptr operands
+ // from the use_device_ptr list into the front of the use_device_addr list
+ // whilst maintaining the ordering of useDeviceLocs, useDeviceSyms and
+ // useDeviceTypes to use_device_ptr/use_device_addr input for BlockArg
+ // ordering.
+ // TODO: Perhaps create a user provideable compiler option that will
+ // re-introduce a hard-error rather than a warning in these cases.
+ promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes,
+ useDeviceLocs, useDeviceSyms);
+}
+
+static void genTargetEnterExitUpdateDataClauses(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::StatementContext &stmtCtx,
+ const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
+ llvm::omp::Directive directive,
+ mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processDepend(clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processIf(directive, clauseOps);
+ cp.processNowait(clauseOps);
+
+ if (directive == llvm::omp::Directive::OMPD_target_update) {
+ cp.processMotionClauses<clause::To>(stmtCtx, clauseOps);
+ cp.processMotionClauses<clause::From>(stmtCtx, clauseOps);
+ } else {
+ cp.processMap(loc, stmtCtx, clauseOps);
+ }
+}
+
+static void genTaskClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::StatementContext &stmtCtx,
+ const Fortran::parser::OmpClauseList &clauses,
+ mlir::Location loc,
+ mlir::omp::TaskClauseOps &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processDefault();
+ cp.processDepend(clauseOps);
cp.processFinal(stmtCtx, clauseOps);
- cp.processUntied(clauseOps);
+ cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
cp.processMergeable(clauseOps);
cp.processPriority(stmtCtx, clauseOps);
- cp.processDepend(clauseOps);
+ cp.processUntied(clauseOps);
// TODO Support delayed privatization.
- cp.processTODO<clause::InReduction, clause::Detach, clause::Affinity>(
- currentLocation, llvm::omp::Directive::OMPD_task);
+ cp.processTODO<clause::Affinity, clause::Detach, clause::InReduction>(
+ loc, llvm::omp::Directive::OMPD_task);
+}
- return genOpWithBody<mlir::omp::TaskOp>(
- OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested)
- .setClauses(&clauseList),
- clauseOps);
+static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const Fortran::parser::OmpClauseList &clauses,
+ mlir::Location loc,
+ mlir::omp::TaskgroupClauseOps &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processAllocate(clauseOps);
+ cp.processTODO<clause::TaskReduction>(loc,
+ llvm::omp::Directive::OMPD_taskgroup);
}
-static mlir::omp::TaskgroupOp
-genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation,
- const Fortran::parser::OmpClauseList &clauseList) {
- mlir::omp::TaskgroupClauseOps clauseOps;
+static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const Fortran::parser::OmpClauseList &clauses,
+ mlir::Location loc,
+ mlir::omp::TaskwaitClauseOps &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processTODO<clause::Depend, clause::Nowait>(
+ loc, llvm::omp::Directive::OMPD_taskwait);
+}
- ClauseProcessor cp(converter, semaCtx, clauseList);
+static void genTeamsClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::StatementContext &stmtCtx,
+ const Fortran::parser::OmpClauseList &clauses,
+ mlir::Location loc,
+ mlir::omp::TeamsClauseOps &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
- cp.processTODO<clause::TaskReduction>(currentLocation,
- llvm::omp::Directive::OMPD_taskgroup);
+ cp.processDefault();
+ cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
+ cp.processNumTeams(stmtCtx, clauseOps);
+ cp.processThreadLimit(stmtCtx, clauseOps);
+ // TODO Support delayed privatization.
- return genOpWithBody<mlir::omp::TaskgroupOp>(
- OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested)
- .setClauses(&clauseList),
- clauseOps);
+ cp.processTODO<clause::Reduction>(loc, llvm::omp::Directive::OMPD_teams);
}
-// This helper function implements the functionality of "promoting"
-// non-CPTR arguments of use_device_ptr to use_device_addr
-// arguments (automagic conversion of use_device_ptr ->
-// use_device_addr in these cases). The way we do so currently is
-// through the shuffling of operands from the devicePtrOperands to
-// deviceAddrOperands where neccesary and re-organizing the types,
-// locations and symbols to maintain the correct ordering of ptr/addr
-// input -> BlockArg.
-//
-// This effectively implements some deprecated OpenMP functionality
-// that some legacy applications unfortunately depend on
-// (deprecated in specification version 5.2):
-//
-// "If a list item in a use_device_ptr clause is not of type C_PTR,
-// the behavior is as if the list item appeared in a use_device_addr
-// clause. Support for such list items in a use_device_ptr clause
-// is deprecated."
-static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
- mlir::omp::UseDeviceClauseOps &clauseOps,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) {
- auto moveElementToBack = [](size_t idx, auto &vector) {
- auto *iter = std::next(vector.begin(), idx);
- vector.push_back(*iter);
- vector.erase(iter);
- };
+static void genWsloopClauses(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::StatementContext &stmtCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OmpClauseList &beginClauses,
+ const Fortran::parser::OmpClauseList *endClauses, mlir::Location loc,
+ mlir::omp::WsloopClauseOps &clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
+ llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ ClauseProcessor bcp(converter, semaCtx, beginClauses);
+ bcp.processCollapse(loc, eval, clauseOps, iv);
+ bcp.processOrdered(clauseOps);
+ bcp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
+ bcp.processSchedule(stmtCtx, clauseOps);
+ clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
+ // TODO Support delayed privatization.
- // Iterate over our use_device_ptr list and shift all non-cptr arguments into
- // use_device_addr.
- for (auto *it = clauseOps.useDevicePtrVars.begin();
- it != clauseOps.useDevicePtrVars.end();) {
- if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) {
- clauseOps.useDeviceAddrVars.push_back(*it);
- // We have to shuffle the symbols around as well, to maintain
- // the correct Input -> BlockArg for use_device_ptr/use_device_addr.
- // NOTE: However, as map's do not seem to be included currently
- // this isn't as pertinent, but we must try to maintain for
- // future alterations. I believe the reason they are not currently
- // is that the BlockArg assign/lowering needs to be extended
- // to a greater set of types.
- auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it);
- moveElementToBack(idx, useDeviceTypes);
- moveElementToBack(idx, useDeviceLocs);
- moveElementToBack(idx, useDeviceSymbols);
- it = clauseOps.useDevicePtrVars.erase(it);
- continue;
+ if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
+ clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
+
+ if (endClauses) {
+ ClauseProcessor ecp(converter, semaCtx, *endClauses);
+ ecp.processNowait(clauseOps);
+ }
+
+ bcp.processTODO<clause::Allocate, clause::Linear, clause::Order>(
+ loc, llvm::omp::Directive::OMPD_do);
+}
+
+//===----------------------------------------------------------------------===//
+// Code generation functions for leaf constructs
+//===----------------------------------------------------------------------===//
+
+static mlir::omp::BarrierOp
+genBarrierOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc) {
+ return converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(loc);
+}
+
+static mlir::omp::CriticalOp
+genCriticalOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList,
+ const std::optional<Fortran::parser::Name> &name) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ mlir::FlatSymbolRefAttr nameAttr;
+
+ if (name) {
+ std::string nameStr = name->ToString();
+ mlir::ModuleOp mod = firOpBuilder.getModule();
+ auto global = mod.lookupSymbol<mlir::omp::CriticalDeclareOp>(nameStr);
+ if (!global) {
+ mlir::omp::CriticalClauseOps clauseOps;
+ genCriticalDeclareClauses(converter, semaCtx, clauseList, loc, clauseOps,
+ nameStr);
+
+ mlir::OpBuilder modBuilder(mod.getBodyRegion());
+ global = modBuilder.create<mlir::omp::CriticalDeclareOp>(loc, clauseOps);
}
- ++it;
+ nameAttr = mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(),
+ global.getSymName());
}
+
+ return genOpWithBody<mlir::omp::CriticalOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested),
+ nameAttr);
}
-static mlir::omp::TargetDataOp
-genTargetDataOp(Fortran::lower::AbstractConverter &converter,
+static mlir::omp::DistributeOp
+genDistributeOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation,
+ mlir::Location loc,
const Fortran::parser::OmpClauseList &clauseList) {
- Fortran::lower::StatementContext stmtCtx;
- mlir::omp::TargetDataClauseOps clauseOps;
- llvm::SmallVector<mlir::Type> useDeviceTypes;
- llvm::SmallVector<mlir::Location> useDeviceLocs;
- llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms;
+ TODO(loc, "Distribute construct");
+ return nullptr;
+}
- ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps);
- cp.processDevice(stmtCtx, clauseOps);
- cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs,
- useDeviceSyms);
- cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs,
- useDeviceSyms);
+static mlir::omp::FlushOp
+genFlushOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const std::optional<Fortran::parser::OmpObjectList> &objectList,
+ const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>
+ &clauseList) {
+ llvm::SmallVector<mlir::Value> operandRange;
+ genFlushClauses(converter, semaCtx, objectList, clauseList, loc,
+ operandRange);
+
+ return converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
+ converter.getCurrentLocation(), operandRange);
+}
- // This function implements the deprecated functionality of use_device_ptr
- // that allows users to provide non-CPTR arguments to it with the caveat
- // that the compiler will treat them as use_device_addr. A lot of legacy
- // code may still depend on this functionality, so we should support it
- // in some manner. We do so currently by simply shifting non-cptr operands
- // from the use_device_ptr list into the front of the use_device_addr list
- // whilst maintaining the ordering of useDeviceLocs, useDeviceSymbols and
- // useDeviceTypes to use_device_ptr/use_device_addr input for BlockArg
- // ordering.
- // TODO: Perhaps create a user provideable compiler option that will
- // re-introduce a hard-error rather than a warning in these cases.
- promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes,
- useDeviceLocs, useDeviceSyms);
- cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data,
- stmtCtx, clauseOps);
+static mlir::omp::MasterOp
+genMasterOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::Location loc) {
+ return genOpWithBody<mlir::omp::MasterOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested),
+ /*resultTypes=*/mlir::TypeRange());
+}
+
+static mlir::omp::OrderedOp
+genOrderedOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ TODO(loc, "OMPD_ordered");
+ return nullptr;
+}
- auto dataOp = converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>(
- currentLocation, clauseOps);
+static mlir::omp::OrderedRegionOp
+genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::omp::OrderedRegionClauseOps clauseOps;
+ genOrderedRegionClauses(converter, semaCtx, clauseList, loc, clauseOps);
- genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, dataOp,
- useDeviceTypes, useDeviceLocs, useDeviceSyms,
- currentLocation);
- return dataOp;
+ return genOpWithBody<mlir::omp::OrderedRegionOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested),
+ clauseOps);
}
-template <typename OpTy>
-static OpTy genTargetEnterExitDataUpdateOp(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- mlir::Location currentLocation,
- const Fortran::parser::OmpClauseList &clauseList) {
+static mlir::omp::ParallelOp
+genParallelOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList,
+ bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
- mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps;
+ mlir::omp::ParallelClauseOps clauseOps;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms;
+ llvm::SmallVector<mlir::Type> reductionTypes;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
+ genParallelClauses(converter, semaCtx, stmtCtx, clauseList, loc,
+ /*processReduction=*/!outerCombined, clauseOps,
+ reductionTypes, reductionSyms);
- // GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
- [[maybe_unused]] llvm::omp::Directive directive;
- if constexpr (std::is_same_v<OpTy, mlir::omp::TargetEnterDataOp>) {
- directive = llvm::omp::Directive::OMPD_target_enter_data;
- } else if constexpr (std::is_same_v<OpTy, mlir::omp::TargetExitDataOp>) {
- directive = llvm::omp::Directive::OMPD_target_exit_data;
- } else if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) {
- directive = llvm::omp::Directive::OMPD_target_update;
- } else {
- return nullptr;
- }
+ auto reductionCallback = [&](mlir::Operation *op) {
+ genReductionVars(op, converter, loc, reductionSyms, reductionTypes);
+ return reductionSyms;
+ };
- ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(directive, clauseOps);
- cp.processDevice(stmtCtx, clauseOps);
- cp.processDepend(clauseOps);
- cp.processNowait(clauseOps);
+ OpWithBodyGenInfo genInfo =
+ OpWithBodyGenInfo(converter, semaCtx, loc, eval)
+ .setGenNested(genNested)
+ .setOuterCombined(outerCombined)
+ .setClauses(&clauseList)
+ .setReductions(&reductionSyms, &reductionTypes)
+ .setGenRegionEntryCb(reductionCallback);
- if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) {
- cp.processMotionClauses<clause::To>(stmtCtx, clauseOps);
- cp.processMotionClauses<clause::From>(stmtCtx, clauseOps);
- } else {
- cp.processMap(currentLocation, directive, stmtCtx, clauseOps);
- }
+ if (!enableDelayedPrivatization)
+ return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
- return firOpBuilder.create<OpTy>(currentLocation, clauseOps);
-}
+ bool privatize = !outerCombined;
+ DataSharingProcessor dsp(converter, semaCtx, clauseList, eval,
+ /*useDelayedPrivatization=*/true, &symTable);
-// This functions creates a block for the body of the targetOp's region. It adds
-// all the symbols present in mapSymbols as block arguments to this block.
-static void
-genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::omp::TargetOp &targetOp,
- llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms,
- llvm::ArrayRef<mlir::Location> mapSymLocs,
- llvm::ArrayRef<mlir::Type> mapSymTypes,
- const mlir::Location &currentLocation) {
- assert(mapSymTypes.size() == mapSymLocs.size());
+ if (privatize)
+ dsp.processStep1(&clauseOps, &privateSyms);
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::Region &region = targetOp.getRegion();
+ auto genRegionEntryCB = [&](mlir::Operation *op) {
+ auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
- auto *regionBlock =
- firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
+ llvm::SmallVector<mlir::Location> reductionLocs(
+ clauseOps.reductionVars.size(), loc);
- // Clones the `bounds` placing them inside the target region and returns them.
- auto cloneBound = [&](mlir::Value bound) {
- if (mlir::isMemoryEffectFree(bound.getDefiningOp())) {
- mlir::Operation *clonedOp = bound.getDefiningOp()->clone();
- regionBlock->push_back(clonedOp);
- return clonedOp->getResult(0);
+ mlir::OperandRange privateVars = parallelOp.getPrivateVars();
+ mlir::Region &region = parallelOp.getRegion();
+
+ llvm::SmallVector<mlir::Type> privateVarTypes = reductionTypes;
+ privateVarTypes.reserve(privateVarTypes.size() + privateVars.size());
+ llvm::transform(privateVars, std::back_inserter(privateVarTypes),
+ [](mlir::Value v) { return v.getType(); });
+
+ llvm::SmallVector<mlir::Location> privateVarLocs = reductionLocs;
+ privateVarLocs.reserve(privateVarLocs.size() + privateVars.size());
+ llvm::transform(privateVars, std::back_inserter(privateVarLocs),
+ [](mlir::Value v) { return v.getLoc(); });
+
+ firOpBuilder.createBlock(&region, /*insertPt=*/{}, privateVarTypes,
+ privateVarLocs);
+
+ llvm::SmallVector<const Fortran::semantics::Symbol *> allSymbols =
+ reductionSyms;
+ allSymbols.append(privateSyms);
+ for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) {
+ converter.bindSymbol(*arg, prv);
}
- TODO(converter.getCurrentLocation(),
- "target map clause operand unsupported bound type");
- };
- auto cloneBounds = [cloneBound](llvm::ArrayRef<mlir::Value> bounds) {
- llvm::SmallVector<mlir::Value> clonedBounds;
- for (mlir::Value bound : bounds)
- clonedBounds.emplace_back(cloneBound(bound));
- return clonedBounds;
+ return allSymbols;
};
- // Bind the symbols to their corresponding block arguments.
- for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) {
- const mlir::BlockArgument &arg = region.getArgument(argIndex);
- // Avoid capture of a reference to a structured binding.
- const Fortran::semantics::Symbol *sym = argSymbol;
- // Structure component symbols don't have bindings.
- if (sym->owner().IsDerivedType())
- continue;
- fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
- extVal.match(
- [&](const fir::BoxValue &v) {
- converter.bindSymbol(*sym,
- fir::BoxValue(arg, cloneBounds(v.getLBounds()),
- v.getExplicitParameters(),
- v.getExplicitExtents()));
- },
- [&](const fir::MutableBoxValue &v) {
- converter.bindSymbol(
- *sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()),
- v.getMutableProperties()));
- },
- [&](const fir::ArrayBoxValue &v) {
- converter.bindSymbol(
- *sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()),
- cloneBounds(v.getLBounds()),
- v.getSourceBox()));
- },
- [&](const fir::CharArrayBoxValue &v) {
- converter.bindSymbol(
- *sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()),
- cloneBounds(v.getExtents()),
- cloneBounds(v.getLBounds())));
- },
- [&](const fir::CharBoxValue &v) {
- converter.bindSymbol(*sym,
- fir::CharBoxValue(arg, cloneBound(v.getLen())));
- },
- [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); },
- [&](const auto &) {
- TODO(converter.getCurrentLocation(),
- "target map clause operand unsupported type");
- });
- }
+ // TODO Merge with the reduction CB.
+ genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
+ return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
+}
- // Check if cloning the bounds introduced any dependency on the outer region.
- // If so, then either clone them as well if they are MemoryEffectFree, or else
- // copy them to a new temporary and add them to the map and block_argument
- // lists and replace their uses with the new temporary.
- llvm::SetVector<mlir::Value> valuesDefinedAbove;
- mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
- while (!valuesDefinedAbove.empty()) {
- for (mlir::Value val : valuesDefinedAbove) {
- mlir::Operation *valOp = val.getDefiningOp();
- if (mlir::isMemoryEffectFree(valOp)) {
- mlir::Operation *clonedOp = valOp->clone();
- regionBlock->push_front(clonedOp);
- val.replaceUsesWithIf(
- clonedOp->getResult(0), [regionBlock](mlir::OpOperand &use) {
- return use.getOwner()->getBlock() == regionBlock;
- });
- } else {
- auto savedIP = firOpBuilder.getInsertionPoint();
- firOpBuilder.setInsertionPointAfter(valOp);
- auto copyVal =
- firOpBuilder.createTemporary(val.getLoc(), val.getType());
- firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal);
+static mlir::omp::SectionOp
+genSectionOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ // Currently only private/firstprivate clause is handled, and
+ // all privatization is done within `omp.section` operations.
+ return genOpWithBody<mlir::omp::SectionOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, eval)
+ .setGenNested(genNested)
+ .setClauses(&clauseList));
+}
- llvm::SmallVector<mlir::Value> bounds;
- std::stringstream name;
- firOpBuilder.setInsertionPoint(targetOp);
- mlir::Value mapOp = createMapInfoOp(
- firOpBuilder, copyVal.getLoc(), copyVal, mlir::Value{}, name.str(),
- bounds, llvm::SmallVector<mlir::Value>{},
- static_cast<
- std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
- mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType());
- targetOp.getMapOperandsMutable().append(mapOp);
- mlir::Value clonedValArg =
- region.addArgument(copyVal.getType(), copyVal.getLoc());
- firOpBuilder.setInsertionPointToStart(regionBlock);
- auto loadOp = firOpBuilder.create<fir::LoadOp>(clonedValArg.getLoc(),
- clonedValArg);
- val.replaceUsesWithIf(
- loadOp->getResult(0), [regionBlock](mlir::OpOperand &use) {
- return use.getOwner()->getBlock() == regionBlock;
- });
- firOpBuilder.setInsertionPoint(regionBlock, savedIP);
- }
- }
- valuesDefinedAbove.clear();
- mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
- }
+static mlir::omp::SectionsOp
+genSectionsOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const mlir::omp::SectionsClauseOps &clauseOps) {
+ return genOpWithBody<mlir::omp::SectionsOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(false),
+ clauseOps);
+}
- // Insert dummy instruction to remember the insertion position. The
- // marker will be deleted since there are not uses.
- // In the HLFIR flow there are hlfir.declares inserted above while
- // setting block arguments.
- mlir::Value undefMarker = firOpBuilder.create<fir::UndefOp>(
- targetOp.getOperation()->getLoc(), firOpBuilder.getIndexType());
+static mlir::omp::SimdLoopOp
+genSimdLoopOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ DataSharingProcessor dsp(converter, semaCtx, clauseList, eval);
+ dsp.processStep1();
- // Create blocks for unstructured regions. This has to be done since
- // blocks are initially allocated with the function as the parent region.
- if (eval.lowerAsUnstructured()) {
- Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
- mlir::omp::YieldOp>(
- firOpBuilder, eval.getNestedEvaluations());
- }
+ Fortran::lower::StatementContext stmtCtx;
+ mlir::omp::SimdLoopClauseOps clauseOps;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
+ genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauseList, loc,
+ clauseOps, iv);
- firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+ auto *nestedEval =
+ getCollapsedLoopEval(eval, Fortran::lower::getCollapseValue(clauseList));
+
+ auto ivCallback = [&](mlir::Operation *op) {
+ return genLoopVars(op, converter, loc, iv);
+ };
+
+ return genOpWithBody<mlir::omp::SimdLoopOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
+ .setClauses(&clauseList)
+ .setDataSharingProcessor(&dsp)
+ .setGenRegionEntryCb(ivCallback),
+ clauseOps);
+}
+
+static mlir::omp::SingleOp
+genSingleOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::Location loc,
+ const Fortran::parser::OmpClauseList &beginClauseList,
+ const Fortran::parser::OmpClauseList &endClauseList) {
+ mlir::omp::SingleClauseOps clauseOps;
+ genSingleClauses(converter, semaCtx, beginClauseList, endClauseList, loc,
+ clauseOps);
- // Create the insertion point after the marker.
- firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
- if (genNested)
- genNestedEvaluations(converter, eval);
+ return genOpWithBody<mlir::omp::SingleOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, eval)
+ .setGenNested(genNested)
+ .setClauses(&beginClauseList),
+ clauseOps);
}
static mlir::omp::TargetOp
genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation,
+ mlir::Location loc,
const Fortran::parser::OmpClauseList &clauseList,
- llvm::omp::Directive directive, bool outerCombined = false) {
+ bool outerCombined = false) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
+
+ bool processHostOnlyClauses =
+ !llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp())
+ .getIsTargetDevice();
+
mlir::omp::TargetClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms;
llvm::SmallVector<mlir::Location> mapSymLocs;
llvm::SmallVector<mlir::Type> mapSymTypes;
-
- ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
- cp.processDevice(stmtCtx, clauseOps);
- cp.processThreadLimit(stmtCtx, clauseOps);
- cp.processDepend(clauseOps);
- cp.processNowait(clauseOps);
- cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms,
- &mapSymLocs, &mapSymTypes);
- // TODO Support delayed privatization.
-
- cp.processTODO<clause::Private, clause::Firstprivate, clause::IsDevicePtr,
- clause::HasDeviceAddr, clause::Reduction, clause::InReduction,
- clause::Allocate, clause::UsesAllocators, clause::Defaultmap>(
- currentLocation, llvm::omp::Directive::OMPD_target);
+ genTargetClauses(converter, semaCtx, stmtCtx, clauseList, loc,
+ processHostOnlyClauses, /*processReduction=*/outerCombined,
+ clauseOps, mapSyms, mapSymLocs, mapSymTypes);
// 5.8.1 Implicit Data-Mapping Attribute Rules
// The following code follows the implicit data-mapping rules to map all the
@@ -1131,338 +1584,145 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
};
Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
- auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
- currentLocation, clauseOps);
-
- genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms,
- mapSymLocs, mapSymTypes, currentLocation);
-
- return targetOp;
-}
-
-static mlir::omp::TeamsOp
-genTeamsOp(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation,
- const Fortran::parser::OmpClauseList &clauseList,
- bool outerCombined = false) {
- Fortran::lower::StatementContext stmtCtx;
- mlir::omp::TeamsClauseOps clauseOps;
-
- ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
- cp.processAllocate(clauseOps);
- cp.processDefault();
- cp.processNumTeams(stmtCtx, clauseOps);
- cp.processThreadLimit(stmtCtx, clauseOps);
- // TODO Support delayed privatization.
-
- cp.processTODO<clause::Reduction>(currentLocation,
- llvm::omp::Directive::OMPD_teams);
-
- return genOpWithBody<mlir::omp::TeamsOp>(
- OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested)
- .setOuterCombined(outerCombined)
- .setClauses(&clauseList),
- clauseOps);
-}
-
-/// Extract the list of function and variable symbols affected by the given
-/// 'declare target' directive and return the intended device type for them.
-static void getDeclareTargetInfo(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
- mlir::omp::DeclareTargetClauseOps &clauseOps,
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
- const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
- declareTargetConstruct.t);
- if (const auto *objectList{
- Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
- ObjectList objects{makeList(*objectList, semaCtx)};
- // Case: declare target(func, var1, var2)
- gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
- symbolAndClause);
- } else if (const auto *clauseList{
- Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
- spec.u)}) {
- if (clauseList->v.empty()) {
- // Case: declare target, implicit capture of function
- symbolAndClause.emplace_back(
- mlir::omp::DeclareTargetCaptureClause::to,
- eval.getOwningProcedure()->getSubprogramSymbol());
- }
-
- ClauseProcessor cp(converter, semaCtx, *clauseList);
- cp.processTo(symbolAndClause);
- cp.processEnter(symbolAndClause);
- cp.processLink(symbolAndClause);
- cp.processDeviceType(clauseOps);
- cp.processTODO<clause::Indirect>(converter.getCurrentLocation(),
- llvm::omp::Directive::OMPD_declare_target);
- }
-}
-
-static void collectDeferredDeclareTargets(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
- llvm::SmallVectorImpl<Fortran::lower::OMPDeferredDeclareTargetInfo>
- &deferredDeclareTarget) {
- mlir::omp::DeclareTargetClauseOps clauseOps;
- llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
- getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
- clauseOps, symbolAndClause);
- // Return the device type only if at least one of the targets for the
- // directive is a function or subroutine
- mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
-
- for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
- mlir::Operation *op = mod.lookupSymbol(converter.mangleName(
- std::get<const Fortran::semantics::Symbol &>(symClause)));
-
- if (!op) {
- deferredDeclareTarget.push_back({std::get<0>(symClause),
- clauseOps.deviceType,
- std::get<1>(symClause)});
- }
- }
-}
-
-static std::optional<mlir::omp::DeclareTargetDeviceType>
-getDeclareTargetFunctionDevice(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenMPDeclareTargetConstruct
- &declareTargetConstruct) {
- mlir::omp::DeclareTargetClauseOps clauseOps;
- llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
- getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
- clauseOps, symbolAndClause);
-
- // Return the device type only if at least one of the targets for the
- // directive is a function or subroutine
- mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
- for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
- mlir::Operation *op = mod.lookupSymbol(converter.mangleName(
- std::get<const Fortran::semantics::Symbol &>(symClause)));
-
- if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
- return clauseOps.deviceType;
- }
-
- return std::nullopt;
-}
-
-//===----------------------------------------------------------------------===//
-// genOMP() Code generation helper functions
-//===----------------------------------------------------------------------===//
-
-static void
-genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- const Fortran::parser::OpenMPSimpleStandaloneConstruct
- &simpleStandaloneConstruct) {
- const auto &directive =
- std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
- simpleStandaloneConstruct.t);
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- const auto &opClauseList =
- std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t);
- mlir::Location currentLocation = converter.genLocation(directive.source);
-
- switch (directive.v) {
- default:
- break;
- case llvm::omp::Directive::OMPD_barrier:
- firOpBuilder.create<mlir::omp::BarrierOp>(currentLocation);
- break;
- case llvm::omp::Directive::OMPD_taskwait: {
- mlir::omp::TaskwaitClauseOps clauseOps;
- ClauseProcessor cp(converter, semaCtx, opClauseList);
- cp.processTODO<clause::Depend, clause::Nowait>(
- currentLocation, llvm::omp::Directive::OMPD_taskwait);
- firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation, clauseOps);
- break;
- }
- case llvm::omp::Directive::OMPD_taskyield:
- firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation);
- break;
- case llvm::omp::Directive::OMPD_target_data:
- genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation,
- opClauseList);
- break;
- case llvm::omp::Directive::OMPD_target_enter_data:
- genTargetEnterExitDataUpdateOp<mlir::omp::TargetEnterDataOp>(
- converter, semaCtx, currentLocation, opClauseList);
- break;
- case llvm::omp::Directive::OMPD_target_exit_data:
- genTargetEnterExitDataUpdateOp<mlir::omp::TargetExitDataOp>(
- converter, semaCtx, currentLocation, opClauseList);
- break;
- case llvm::omp::Directive::OMPD_target_update:
- genTargetEnterExitDataUpdateOp<mlir::omp::TargetUpdateOp>(
- converter, semaCtx, currentLocation, opClauseList);
- break;
- case llvm::omp::Directive::OMPD_ordered:
- TODO(currentLocation, "OMPD_ordered");
- }
-}
-
-static void
-genOmpFlush(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
- llvm::SmallVector<mlir::Value, 4> operandRange;
- if (const auto &ompObjectList =
- std::get<std::optional<Fortran::parser::OmpObjectList>>(
- flushConstruct.t))
- genObjectList2(*ompObjectList, converter, operandRange);
- const auto &memOrderClause =
- std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
- flushConstruct.t);
- if (memOrderClause && memOrderClause->size() > 0)
- TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause");
- converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
- converter.getCurrentLocation(), operandRange);
-}
-
-static llvm::SmallVector<const Fortran::semantics::Symbol *>
-genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
- mlir::Location &loc,
- llvm::ArrayRef<const Fortran::semantics::Symbol *> args) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- auto &region = op->getRegion(0);
-
- std::size_t loopVarTypeSize = 0;
- for (const Fortran::semantics::Symbol *arg : args)
- loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
- mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
- llvm::SmallVector<mlir::Location> locs(args.size(), loc);
- firOpBuilder.createBlock(&region, {}, tiv, locs);
- // The argument is not currently in memory, so make a temporary for the
- // argument, and store it there, then bind that location to the argument.
- mlir::Operation *storeOp = nullptr;
- for (auto [argIndex, argSymbol] : llvm::enumerate(args)) {
- mlir::Value indexVal = fir::getBase(region.front().getArgument(argIndex));
- storeOp =
- createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
- }
- firOpBuilder.setInsertionPointAfter(storeOp);
-
- return llvm::SmallVector<const Fortran::semantics::Symbol *>(args);
-}
-
-static llvm::SmallVector<const Fortran::semantics::Symbol *>
-genLoopAndReductionVars(
- mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
- mlir::Location &loc,
- llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
- llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
- llvm::ArrayRef<mlir::Type> reductionTypes) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- llvm::SmallVector<mlir::Type> blockArgTypes;
- llvm::SmallVector<mlir::Location> blockArgLocs;
- blockArgTypes.reserve(loopArgs.size() + reductionArgs.size());
- blockArgLocs.reserve(blockArgTypes.size());
- mlir::Block *entryBlock;
-
- if (loopArgs.size()) {
- std::size_t loopVarTypeSize = 0;
- for (const Fortran::semantics::Symbol *arg : loopArgs)
- loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
- mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(),
- loopVarType);
- std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc);
- }
- if (reductionArgs.size()) {
- llvm::copy(reductionTypes, std::back_inserter(blockArgTypes));
- std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc);
- }
- entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes,
- blockArgLocs);
- // The argument is not currently in memory, so make a temporary for the
- // argument, and store it there, then bind that location to the argument.
- if (loopArgs.size()) {
- mlir::Operation *storeOp = nullptr;
- for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
- mlir::Value indexVal =
- fir::getBase(op->getRegion(0).front().getArgument(argIndex));
- storeOp =
- createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
- }
- firOpBuilder.setInsertionPointAfter(storeOp);
- }
- // Bind the reduction arguments to their block arguments
- for (auto [arg, prv] : llvm::zip_equal(
- reductionArgs,
- llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) {
- converter.bindSymbol(*arg, prv);
- }
-
- return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs);
+ auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(loc, clauseOps);
+ genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms,
+ mapSymLocs, mapSymTypes, loc);
+ return targetOp;
}
-static void
-createSimdLoop(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- llvm::omp::Directive ompDirective,
- const Fortran::parser::OmpClauseList &loopOpClauseList,
- mlir::Location loc) {
+static mlir::omp::TargetDataOp
+genTargetDataOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ Fortran::lower::StatementContext stmtCtx;
+ mlir::omp::TargetDataClauseOps clauseOps;
+ llvm::SmallVector<mlir::Type> useDeviceTypes;
+ llvm::SmallVector<mlir::Location> useDeviceLocs;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms;
+ genTargetDataClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms);
+
+ auto targetDataOp =
+ converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>(loc,
+ clauseOps);
+
+ genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, targetDataOp,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms, loc);
+ return targetDataOp;
+}
+
+template <typename OpTy>
+static OpTy genTargetEnterExitUpdateDataOp(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- DataSharingProcessor dsp(converter, semaCtx, loopOpClauseList, eval);
- dsp.processStep1();
+ Fortran::lower::StatementContext stmtCtx;
+
+ // GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
+ [[maybe_unused]] llvm::omp::Directive directive;
+ if constexpr (std::is_same_v<OpTy, mlir::omp::TargetEnterDataOp>) {
+ directive = llvm::omp::Directive::OMPD_target_enter_data;
+ } else if constexpr (std::is_same_v<OpTy, mlir::omp::TargetExitDataOp>) {
+ directive = llvm::omp::Directive::OMPD_target_exit_data;
+ } else if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) {
+ directive = llvm::omp::Directive::OMPD_target_update;
+ } else {
+ llvm_unreachable("Unexpected TARGET DATA construct");
+ }
+
+ mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps;
+ genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauseList,
+ loc, directive, clauseOps);
+
+ return firOpBuilder.create<OpTy>(loc, clauseOps);
+}
+static mlir::omp::TaskOp
+genTaskOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList) {
Fortran::lower::StatementContext stmtCtx;
- mlir::omp::SimdLoopClauseOps clauseOps;
- llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
+ mlir::omp::TaskClauseOps clauseOps;
+ genTaskClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps);
- ClauseProcessor cp(converter, semaCtx, loopOpClauseList);
- cp.processCollapse(loc, eval, clauseOps, iv);
- cp.processReduction(loc, clauseOps);
- cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps);
- cp.processSimdlen(clauseOps);
- cp.processSafelen(clauseOps);
- clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
- // TODO Support delayed privatization.
+ return genOpWithBody<mlir::omp::TaskOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, eval)
+ .setGenNested(genNested)
+ .setClauses(&clauseList),
+ clauseOps);
+}
- cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
- clause::Nontemporal, clause::Order>(loc, ompDirective);
+static mlir::omp::TaskgroupOp
+genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::omp::TaskgroupClauseOps clauseOps;
+ genTaskgroupClauses(converter, semaCtx, clauseList, loc, clauseOps);
- auto *nestedEval = getCollapsedLoopEval(
- eval, Fortran::lower::getCollapseValue(loopOpClauseList));
+ return genOpWithBody<mlir::omp::TaskgroupOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, eval)
+ .setGenNested(genNested)
+ .setClauses(&clauseList),
+ clauseOps);
+}
- auto ivCallback = [&](mlir::Operation *op) {
- return genLoopVars(op, converter, loc, iv);
- };
+static mlir::omp::TaskloopOp
+genTaskloopOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ TODO(loc, "Taskloop construct");
+}
- genOpWithBody<mlir::omp::SimdLoopOp>(
- OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&loopOpClauseList)
- .setDataSharingProcessor(&dsp)
- .setGenRegionEntryCb(ivCallback),
+static mlir::omp::TaskwaitOp
+genTaskwaitOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::omp::TaskwaitClauseOps clauseOps;
+ genTaskwaitClauses(converter, semaCtx, clauseList, loc, clauseOps);
+ return converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(loc,
+ clauseOps);
+}
+
+static mlir::omp::TaskyieldOp
+genTaskyieldOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc) {
+ return converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(loc);
+}
+
+static mlir::omp::TeamsOp
+genTeamsOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, bool genNested,
+ mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList,
+ bool outerCombined = false) {
+ Fortran::lower::StatementContext stmtCtx;
+ mlir::omp::TeamsClauseOps clauseOps;
+ genTeamsClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps);
+
+ return genOpWithBody<mlir::omp::TeamsOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, eval)
+ .setGenNested(genNested)
+ .setOuterCombined(outerCombined)
+ .setClauses(&clauseList),
clauseOps);
}
-static void createWsloop(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- llvm::omp::Directive ompDirective,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList,
- mlir::Location loc) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+static mlir::omp::WsloopOp
+genWsloopOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const Fortran::parser::OmpClauseList &beginClauseList,
+ const Fortran::parser::OmpClauseList *endClauseList) {
DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval);
dsp.processStep1();
@@ -1471,30 +1731,9 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Type> reductionTypes;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
-
- ClauseProcessor cp(converter, semaCtx, beginClauseList);
- cp.processCollapse(loc, eval, clauseOps, iv);
- cp.processSchedule(stmtCtx, clauseOps);
- cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
- cp.processOrdered(clauseOps);
- clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
- // TODO Support delayed privatization.
-
- if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
- clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
-
- cp.processTODO<clause::Allocate, clause::Linear, clause::Order>(loc,
- ompDirective);
-
- // In FORTRAN `nowait` clause occur at the end of `omp do` directive.
- // i.e
- // !$omp do
- // <...>
- // !$omp end do nowait
- if (endClauseList) {
- ClauseProcessor ecp(converter, semaCtx, *endClauseList);
- ecp.processNowait(clauseOps);
- }
+ genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauseList,
+ endClauseList, loc, clauseOps, iv, reductionTypes,
+ reductionSyms);
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(beginClauseList));
@@ -1504,7 +1743,7 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
reductionTypes);
};
- genOpWithBody<mlir::omp::WsloopOp>(
+ return genOpWithBody<mlir::omp::WsloopOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
.setClauses(&beginClauseList)
.setDataSharingProcessor(&dsp)
@@ -1513,7 +1752,11 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
clauseOps);
}
-static void createSimdWsloop(
+//===----------------------------------------------------------------------===//
+// Code generation functions for composite constructs
+//===----------------------------------------------------------------------===//
+
+static void genCompositeDoSimd(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive ompDirective,
@@ -1521,7 +1764,7 @@ static void createSimdWsloop(
const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
ClauseProcessor cp(converter, semaCtx, beginClauseList);
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
- clause::Safelen, clause::Simdlen, clause::Order>(loc,
+ clause::Order, clause::Safelen, clause::Simdlen>(loc,
ompDirective);
// TODO: Add support for vectorization - add vectorization hints inside loop
// body.
@@ -1531,34 +1774,7 @@ static void createSimdWsloop(
// When support for vectorization is enabled, then we need to add handling of
// if clause. Currently if clause can be skipped because we always assume
// SIMD length = 1.
- createWsloop(converter, semaCtx, eval, ompDirective, beginClauseList,
- endClauseList, loc);
-}
-
-static void
-markDeclareTarget(mlir::Operation *op,
- Fortran::lower::AbstractConverter &converter,
- mlir::omp::DeclareTargetCaptureClause captureClause,
- mlir::omp::DeclareTargetDeviceType deviceType) {
- // TODO: Add support for program local variables with declare target applied
- auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op);
- if (!declareTargetOp)
- fir::emitFatalError(
- converter.getCurrentLocation(),
- "Attempt to apply declare target on unsupported operation");
-
- // The function or global already has a declare target applied to it, very
- // likely through implicit capture (usage in another declare target
- // function/subroutine). It should be marked as any if it has been assigned
- // both host and nohost, else we skip, as there is no change
- if (declareTargetOp.isDeclareTarget()) {
- if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
- declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
- captureClause);
- return;
- }
-
- declareTargetOp.setDeclareTarget(deviceType, captureClause);
+ genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList);
}
//===----------------------------------------------------------------------===//
@@ -1653,6 +1869,102 @@ genOMP(Fortran::lower::AbstractConverter &converter,
ompDeclConstruct.u);
}
+//===----------------------------------------------------------------------===//
+// OpenMPStandaloneConstruct visitors
+//===----------------------------------------------------------------------===//
+
+static void genOMP(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPSimpleStandaloneConstruct
+ &simpleStandaloneConstruct) {
+ const auto &directive =
+ std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
+ simpleStandaloneConstruct.t);
+ const auto &clauseList =
+ std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t);
+ mlir::Location currentLocation = converter.genLocation(directive.source);
+
+ switch (directive.v) {
+ default:
+ break;
+ case llvm::omp::Directive::OMPD_barrier:
+ genBarrierOp(converter, semaCtx, eval, currentLocation);
+ break;
+ case llvm::omp::Directive::OMPD_taskwait:
+ genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauseList);
+ break;
+ case llvm::omp::Directive::OMPD_taskyield:
+ genTaskyieldOp(converter, semaCtx, eval, currentLocation);
+ break;
+ case llvm::omp::Directive::OMPD_target_data:
+ genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true,
+ currentLocation, clauseList);
+ break;
+ case llvm::omp::Directive::OMPD_target_enter_data:
+ genTargetEnterExitUpdateDataOp<mlir::omp::TargetEnterDataOp>(
+ converter, semaCtx, currentLocation, clauseList);
+ break;
+ case llvm::omp::Directive::OMPD_target_exit_data:
+ genTargetEnterExitUpdateDataOp<mlir::omp::TargetExitDataOp>(
+ converter, semaCtx, currentLocation, clauseList);
+ break;
+ case llvm::omp::Directive::OMPD_target_update:
+ genTargetEnterExitUpdateDataOp<mlir::omp::TargetUpdateOp>(
+ converter, semaCtx, currentLocation, clauseList);
+ break;
+ case llvm::omp::Directive::OMPD_ordered:
+ genOrderedOp(converter, semaCtx, eval, currentLocation, clauseList);
+ break;
+ }
+}
+
+static void
+genOMP(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
+ const auto &verbatim = std::get<Fortran::parser::Verbatim>(flushConstruct.t);
+ const auto &objectList =
+ std::get<std::optional<Fortran::parser::OmpObjectList>>(flushConstruct.t);
+ const auto &clauseList =
+ std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
+ flushConstruct.t);
+ mlir::Location currentLocation = converter.genLocation(verbatim.source);
+ genFlushOp(converter, semaCtx, eval, currentLocation, objectList, clauseList);
+}
+
+static void
+genOMP(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
+ TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
+}
+
+static void genOMP(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPCancellationPointConstruct
+ &cancellationPointConstruct) {
+ TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
+}
+
+static void
+genOMP(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
+ std::visit(
+ [&](auto &&s) { return genOMP(converter, symTable, semaCtx, eval, s); },
+ standaloneConstruct.u);
+}
+
//===----------------------------------------------------------------------===//
// OpenMPConstruct visitors
//===----------------------------------------------------------------------===//
@@ -1782,7 +2094,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
break;
case llvm::omp::Directive::OMPD_target:
genTargetOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
- beginClauseList, directive.v);
+ beginClauseList);
break;
case llvm::omp::Directive::OMPD_target_data:
genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true,
@@ -1798,8 +2110,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
break;
case llvm::omp::Directive::OMPD_teams:
genTeamsOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
- beginClauseList,
- /*outerCombined=*/false);
+ beginClauseList);
break;
case llvm::omp::Directive::OMPD_workshare:
// FIXME: Workshare is not a commonly used OpenMP construct, an
@@ -1821,8 +2132,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet)
.test(directive.v)) {
genTargetOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
- beginClauseList, directive.v,
- /*outerCombined=*/true);
+ beginClauseList, /*outerCombined=*/true);
combinedDirective = true;
}
if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
@@ -1859,44 +2169,13 @@ genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::Location currentLocation = converter.getCurrentLocation();
- std::string name;
- const Fortran::parser::OmpCriticalDirective &cd =
+ const auto &cd =
std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
- if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) {
- name =
- std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
- }
-
- mlir::omp::CriticalOp criticalOp = [&]() {
- if (name.empty()) {
- return firOpBuilder.create<mlir::omp::CriticalOp>(
- currentLocation, mlir::FlatSymbolRefAttr());
- }
-
- mlir::ModuleOp module = firOpBuilder.getModule();
- mlir::OpBuilder modBuilder(module.getBodyRegion());
- auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
- if (!global) {
- mlir::omp::CriticalClauseOps clauseOps;
- const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
-
- ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processHint(clauseOps);
- clauseOps.nameAttr =
- mlir::StringAttr::get(firOpBuilder.getContext(), name);
-
- global = modBuilder.create<mlir::omp::CriticalDeclareOp>(currentLocation,
- clauseOps);
- }
-
- return firOpBuilder.create<mlir::omp::CriticalOp>(
- currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(),
- global.getSymName()));
- }();
- auto genInfo = OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval);
- createBodyOfOp<mlir::omp::CriticalOp>(criticalOp, genInfo);
+ const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
+ const auto &name = std::get<std::optional<Fortran::parser::Name>>(cd.t);
+ mlir::Location currentLocation = converter.getCurrentLocation();
+ genCriticalOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
+ clauseList, name);
}
static void
@@ -1915,7 +2194,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
const auto &beginLoopDirective =
std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t);
- const auto &loopOpClauseList =
+ const auto &beginClauseList =
std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t);
mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);
@@ -1936,33 +2215,31 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
bool validDirective = false;
if (llvm::omp::topTaskloopSet.test(ompDirective)) {
validDirective = true;
- TODO(currentLocation, "Taskloop construct");
+ genTaskloopOp(converter, semaCtx, eval, currentLocation, beginClauseList);
} else {
// Create omp.{target, teams, distribute, parallel} nested operations
if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet)
.test(ompDirective)) {
validDirective = true;
genTargetOp(converter, semaCtx, eval, /*genNested=*/false,
- currentLocation, loopOpClauseList, ompDirective,
- /*outerCombined=*/true);
+ currentLocation, beginClauseList, /*outerCombined=*/true);
}
if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
.test(ompDirective)) {
validDirective = true;
genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
- loopOpClauseList,
- /*outerCombined=*/true);
+ beginClauseList, /*outerCombined=*/true);
}
if (llvm::omp::allDistributeSet.test(ompDirective)) {
validDirective = true;
- TODO(currentLocation, "Distribute construct");
+ genDistributeOp(converter, semaCtx, eval, /*genNested=*/false,
+ currentLocation, beginClauseList);
}
if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
.test(ompDirective)) {
validDirective = true;
genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
- currentLocation, loopOpClauseList,
- /*outerCombined=*/true);
+ currentLocation, beginClauseList, /*outerCombined=*/true);
}
}
if ((llvm::omp::allDoSet | llvm::omp::allSimdSet).test(ompDirective))
@@ -1976,17 +2253,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
if (llvm::omp::allDoSimdSet.test(ompDirective)) {
// 2.9.3.2 Workshare SIMD construct
- createSimdWsloop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
- endClauseList, currentLocation);
-
+ genCompositeDoSimd(converter, semaCtx, eval, ompDirective, beginClauseList,
+ endClauseList, currentLocation);
} else if (llvm::omp::allSimdSet.test(ompDirective)) {
// 2.9.3.1 SIMD construct
- createSimdLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
- currentLocation);
- genOpenMPReduction(converter, semaCtx, loopOpClauseList);
+ genSimdLoopOp(converter, semaCtx, eval, currentLocation, beginClauseList);
+ genOpenMPReduction(converter, semaCtx, beginClauseList);
} else {
- createWsloop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
- endClauseList, currentLocation);
+ genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList,
+ endClauseList);
}
}
@@ -2006,44 +2281,39 @@ genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) {
- mlir::Location currentLocation = converter.getCurrentLocation();
- mlir::omp::SectionsClauseOps clauseOps;
const auto &beginSectionsDirective =
std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t);
- const auto §ionsClauseList =
+ const auto &beginClauseList =
std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t);
// Process clauses before optional omp.parallel, so that new variables are
// allocated outside of the parallel region
- ClauseProcessor cp(converter, semaCtx, sectionsClauseList);
- cp.processSectionsReduction(currentLocation, clauseOps);
- cp.processAllocate(clauseOps);
- // TODO Support delayed privatization.
+ mlir::Location currentLocation = converter.getCurrentLocation();
+ mlir::omp::SectionsClauseOps clauseOps;
+ genSectionsClauses(converter, semaCtx, beginClauseList, currentLocation,
+ /*clausesFromBeginSections=*/true, clauseOps);
+ // Parallel wrapper of PARALLEL SECTIONS construct
llvm::omp::Directive dir =
std::get<Fortran::parser::OmpSectionsDirective>(beginSectionsDirective.t)
.v;
-
- // Parallel wrapper of PARALLEL SECTIONS construct
if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
genParallelOp(converter, symTable, semaCtx, eval,
- /*genNested=*/false, currentLocation, sectionsClauseList,
+ /*genNested=*/false, currentLocation, beginClauseList,
/*outerCombined=*/true);
} else {
const auto &endSectionsDirective =
std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
- const auto &endSectionsClauseList =
+ const auto &endClauseList =
std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
- ClauseProcessor(converter, semaCtx, endSectionsClauseList)
- .processNowait(clauseOps);
+ genSectionsClauses(converter, semaCtx, endClauseList, currentLocation,
+ /*clausesFromBeginSections=*/false, clauseOps);
}
- // SECTIONS construct
- genOpWithBody<mlir::omp::SectionsOp>(
- OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(false),
- clauseOps);
+ // SECTIONS construct.
+ genSectionsOp(converter, semaCtx, eval, currentLocation, clauseOps);
+ // Generate nested SECTION operations recursively.
const auto §ionBlocks =
std::get<Fortran::parser::OmpSectionBlocks>(sectionsConstruct.t);
auto &firOpBuilder = converter.getFirOpBuilder();
@@ -2052,40 +2322,12 @@ genOMP(Fortran::lower::AbstractConverter &converter,
llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) {
symTable.pushScope();
genSectionOp(converter, semaCtx, neval, /*genNested=*/true, currentLocation,
- sectionsClauseList);
+ beginClauseList);
symTable.popScope();
firOpBuilder.restoreInsertionPoint(ip);
}
}
-static void
-genOMP(Fortran::lower::AbstractConverter &converter,
- Fortran::lower::SymMap &symTable,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
- std::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
- &simpleStandaloneConstruct) {
- genOmpSimpleStandalone(converter, semaCtx, eval,
- /*genNested=*/true,
- simpleStandaloneConstruct);
- },
- [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
- genOmpFlush(converter, semaCtx, eval, flushConstruct);
- },
- [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
- TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
- },
- [&](const Fortran::parser::OpenMPCancellationPointConstruct
- &cancellationPointConstruct) {
- TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
- },
- },
- standaloneConstruct.u);
-}
-
static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90
index 821196b83c3b99..d3f2a1c7a15936 100644
--- a/flang/test/Lower/OpenMP/FIR/target.f90
+++ b/flang/test/Lower/OpenMP/FIR/target.f90
@@ -411,8 +411,8 @@ end subroutine omp_target_implicit_bounds
!CHECK-LABEL: func.func @_QPomp_target_thread_limit() {
subroutine omp_target_thread_limit
integer :: a
- !CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "a"}
+ !CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32
!CHECK: omp.target thread_limit(%[[VAL_1]] : i32) map_entries(%[[MAP]] -> %[[ARG_0:.*]] : !fir.ref<i32>) {
!CHECK: ^bb0(%[[ARG_0]]: !fir.ref<i32>):
!$omp target map(tofrom: a) thread_limit(64)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 6f72b5a34d069a..51b66327dfb24b 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -490,8 +490,8 @@ end subroutine omp_target_implicit_bounds
!CHECK-LABEL: func.func @_QPomp_target_thread_limit() {
subroutine omp_target_thread_limit
integer :: a
- !CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "a"}
+ !CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32
!CHECK: omp.target thread_limit(%[[VAL_1]] : i32) map_entries(%[[MAP]] -> %{{.*}} : !fir.ref<i32>) {
!CHECK: ^bb0(%{{.*}}: !fir.ref<i32>):
!$omp target map(tofrom: a) thread_limit(64)
diff --git a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
index 33b5971656010a..d849dd206b9439 100644
--- a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
+++ b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
@@ -21,7 +21,7 @@ subroutine only_use_device_ptr
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
!CHECK: omp.target_data use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) use_device_addr(%{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) {
-!CHECK: ^bb0(%{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>):
+!CHECK: ^bb0(%{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
subroutine mix_use_device_ptr_and_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -47,7 +47,7 @@ subroutine only_use_device_addr
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) use_device_addr(%{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) {
-!CHECK: ^bb0(%{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>):
+!CHECK: ^bb0(%{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
subroutine mix_use_device_ptr_and_addr_and_map
use iso_c_binding
integer :: i, j
>From ec0ed50b0d5f9606f0e9a1a3a9999f601bec310f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 29 Mar 2024 13:57:40 +0000
Subject: [PATCH 04/17] [Flang][OpenMP][Lower] Refactor lowering of compound
constructs
This patch simplifies the lowering from PFT to MLIR of OpenMP compound
constructs (i.e. combined and composite).
The new approach consists of iteratively processing the outermost leaf
construct of the given combined construct until it cannot be split further.
Both leaf constructs and composite ones have `gen...()` functions that are
called when appropriate.
This approach enables treating a leaf construct the same way regardless of
whether it appears as part of a combined construct, and it also enables the
lowering of composite constructs as a single unit.
Previous corner cases are now handled in a more straightforward way and
comments pointing to the relevant spec sections are added. Directive sets are
also completed with missing LOOP-related constructs.
---
.../flang/Semantics/openmp-directive-sets.h | 57 ++-
flang/lib/Lower/OpenMP/OpenMP.cpp | 432 ++++++++++++------
2 files changed, 335 insertions(+), 154 deletions(-)
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 91773ae3ea9a3e..842d251b682aa9 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -32,14 +32,14 @@ static const OmpDirectiveSet topDistributeSet{
static const OmpDirectiveSet allDistributeSet{
OmpDirectiveSet{
- llvm::omp::OMPD_target_teams_distribute,
- llvm::omp::OMPD_target_teams_distribute_parallel_do,
- llvm::omp::OMPD_target_teams_distribute_parallel_do_simd,
- llvm::omp::OMPD_target_teams_distribute_simd,
- llvm::omp::OMPD_teams_distribute,
- llvm::omp::OMPD_teams_distribute_parallel_do,
- llvm::omp::OMPD_teams_distribute_parallel_do_simd,
- llvm::omp::OMPD_teams_distribute_simd,
+ Directive::OMPD_target_teams_distribute,
+ Directive::OMPD_target_teams_distribute_parallel_do,
+ Directive::OMPD_target_teams_distribute_parallel_do_simd,
+ Directive::OMPD_target_teams_distribute_simd,
+ Directive::OMPD_teams_distribute,
+ Directive::OMPD_teams_distribute_parallel_do,
+ Directive::OMPD_teams_distribute_parallel_do_simd,
+ Directive::OMPD_teams_distribute_simd,
} | topDistributeSet,
};
@@ -63,10 +63,24 @@ static const OmpDirectiveSet allDoSet{
} | topDoSet,
};
+static const OmpDirectiveSet topLoopSet{
+ Directive::OMPD_loop,
+};
+
+static const OmpDirectiveSet allLoopSet{
+ OmpDirectiveSet{
+ Directive::OMPD_parallel_loop,
+ Directive::OMPD_target_parallel_loop,
+ Directive::OMPD_target_teams_loop,
+ Directive::OMPD_teams_loop,
+ } | topLoopSet,
+};
+
static const OmpDirectiveSet topParallelSet{
Directive::OMPD_parallel,
Directive::OMPD_parallel_do,
Directive::OMPD_parallel_do_simd,
+ Directive::OMPD_parallel_loop,
Directive::OMPD_parallel_masked_taskloop,
Directive::OMPD_parallel_masked_taskloop_simd,
Directive::OMPD_parallel_master_taskloop,
@@ -82,6 +96,7 @@ static const OmpDirectiveSet allParallelSet{
Directive::OMPD_target_parallel,
Directive::OMPD_target_parallel_do,
Directive::OMPD_target_parallel_do_simd,
+ Directive::OMPD_target_parallel_loop,
Directive::OMPD_target_teams_distribute_parallel_do,
Directive::OMPD_target_teams_distribute_parallel_do_simd,
Directive::OMPD_teams_distribute_parallel_do,
@@ -118,12 +133,14 @@ static const OmpDirectiveSet topTargetSet{
Directive::OMPD_target_parallel,
Directive::OMPD_target_parallel_do,
Directive::OMPD_target_parallel_do_simd,
+ Directive::OMPD_target_parallel_loop,
Directive::OMPD_target_simd,
Directive::OMPD_target_teams,
Directive::OMPD_target_teams_distribute,
Directive::OMPD_target_teams_distribute_parallel_do,
Directive::OMPD_target_teams_distribute_parallel_do_simd,
Directive::OMPD_target_teams_distribute_simd,
+ Directive::OMPD_target_teams_loop,
};
static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -156,11 +173,12 @@ static const OmpDirectiveSet topTeamsSet{
static const OmpDirectiveSet allTeamsSet{
OmpDirectiveSet{
- llvm::omp::OMPD_target_teams,
- llvm::omp::OMPD_target_teams_distribute,
- llvm::omp::OMPD_target_teams_distribute_parallel_do,
- llvm::omp::OMPD_target_teams_distribute_parallel_do_simd,
- llvm::omp::OMPD_target_teams_distribute_simd,
+ Directive::OMPD_target_teams,
+ Directive::OMPD_target_teams_distribute,
+ Directive::OMPD_target_teams_distribute_parallel_do,
+ Directive::OMPD_target_teams_distribute_parallel_do_simd,
+ Directive::OMPD_target_teams_distribute_simd,
+ Directive::OMPD_target_teams_loop,
} | topTeamsSet,
};
@@ -178,6 +196,14 @@ static const OmpDirectiveSet allDistributeSimdSet{
static const OmpDirectiveSet allDoSimdSet{allDoSet & allSimdSet};
static const OmpDirectiveSet allTaskloopSimdSet{allTaskloopSet & allSimdSet};
+static const OmpDirectiveSet compositeConstructSet{
+ Directive::OMPD_distribute_parallel_do,
+ Directive::OMPD_distribute_parallel_do_simd,
+ Directive::OMPD_distribute_simd,
+ Directive::OMPD_do_simd,
+ Directive::OMPD_taskloop_simd,
+};
+
static const OmpDirectiveSet blockConstructSet{
Directive::OMPD_master,
Directive::OMPD_ordered,
@@ -201,12 +227,14 @@ static const OmpDirectiveSet loopConstructSet{
Directive::OMPD_distribute_simd,
Directive::OMPD_do,
Directive::OMPD_do_simd,
+ Directive::OMPD_loop,
Directive::OMPD_masked_taskloop,
Directive::OMPD_masked_taskloop_simd,
Directive::OMPD_master_taskloop,
Directive::OMPD_master_taskloop_simd,
Directive::OMPD_parallel_do,
Directive::OMPD_parallel_do_simd,
+ Directive::OMPD_parallel_loop,
Directive::OMPD_parallel_masked_taskloop,
Directive::OMPD_parallel_masked_taskloop_simd,
Directive::OMPD_parallel_master_taskloop,
@@ -214,17 +242,20 @@ static const OmpDirectiveSet loopConstructSet{
Directive::OMPD_simd,
Directive::OMPD_target_parallel_do,
Directive::OMPD_target_parallel_do_simd,
+ Directive::OMPD_target_parallel_loop,
Directive::OMPD_target_simd,
Directive::OMPD_target_teams_distribute,
Directive::OMPD_target_teams_distribute_parallel_do,
Directive::OMPD_target_teams_distribute_parallel_do_simd,
Directive::OMPD_target_teams_distribute_simd,
+ Directive::OMPD_target_teams_loop,
Directive::OMPD_taskloop,
Directive::OMPD_taskloop_simd,
Directive::OMPD_teams_distribute,
Directive::OMPD_teams_distribute_parallel_do,
Directive::OMPD_teams_distribute_parallel_do_simd,
Directive::OMPD_teams_distribute_simd,
+ Directive::OMPD_teams_loop,
Directive::OMPD_tile,
Directive::OMPD_unroll,
};
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 692d81f9188be3..edae453972d3d9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -710,6 +710,81 @@ genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
}
}
+/// Split a combined directive into an outer leaf directive and the (possibly
+/// combined) rest of the combined directive. Composite directives and
+/// non-compound directives are not split, in which case it will return the
+/// input directive as its first output and an empty value as its second output.
+static std::pair<llvm::omp::Directive, std::optional<llvm::omp::Directive>>
+splitCombinedDirective(llvm::omp::Directive dir) {
+ using D = llvm::omp::Directive;
+ switch (dir) {
+ case D::OMPD_masked_taskloop:
+ return {D::OMPD_masked, D::OMPD_taskloop};
+ case D::OMPD_masked_taskloop_simd:
+ return {D::OMPD_masked, D::OMPD_taskloop_simd};
+ case D::OMPD_master_taskloop:
+ return {D::OMPD_master, D::OMPD_taskloop};
+ case D::OMPD_master_taskloop_simd:
+ return {D::OMPD_master, D::OMPD_taskloop_simd};
+ case D::OMPD_parallel_do:
+ return {D::OMPD_parallel, D::OMPD_do};
+ case D::OMPD_parallel_do_simd:
+ return {D::OMPD_parallel, D::OMPD_do_simd};
+ case D::OMPD_parallel_masked:
+ return {D::OMPD_parallel, D::OMPD_masked};
+ case D::OMPD_parallel_masked_taskloop:
+ return {D::OMPD_parallel, D::OMPD_masked_taskloop};
+ case D::OMPD_parallel_masked_taskloop_simd:
+ return {D::OMPD_parallel, D::OMPD_masked_taskloop_simd};
+ case D::OMPD_parallel_master:
+ return {D::OMPD_parallel, D::OMPD_master};
+ case D::OMPD_parallel_master_taskloop:
+ return {D::OMPD_parallel, D::OMPD_master_taskloop};
+ case D::OMPD_parallel_master_taskloop_simd:
+ return {D::OMPD_parallel, D::OMPD_master_taskloop_simd};
+ case D::OMPD_parallel_sections:
+ return {D::OMPD_parallel, D::OMPD_sections};
+ case D::OMPD_parallel_workshare:
+ return {D::OMPD_parallel, D::OMPD_workshare};
+ case D::OMPD_target_parallel:
+ return {D::OMPD_target, D::OMPD_parallel};
+ case D::OMPD_target_parallel_do:
+ return {D::OMPD_target, D::OMPD_parallel_do};
+ case D::OMPD_target_parallel_do_simd:
+ return {D::OMPD_target, D::OMPD_parallel_do_simd};
+ case D::OMPD_target_simd:
+ return {D::OMPD_target, D::OMPD_simd};
+ case D::OMPD_target_teams:
+ return {D::OMPD_target, D::OMPD_teams};
+ case D::OMPD_target_teams_distribute:
+ return {D::OMPD_target, D::OMPD_teams_distribute};
+ case D::OMPD_target_teams_distribute_parallel_do:
+ return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do};
+ case D::OMPD_target_teams_distribute_parallel_do_simd:
+ return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do_simd};
+ case D::OMPD_target_teams_distribute_simd:
+ return {D::OMPD_target, D::OMPD_teams_distribute_simd};
+ case D::OMPD_teams_distribute:
+ return {D::OMPD_teams, D::OMPD_distribute};
+ case D::OMPD_teams_distribute_parallel_do:
+ return {D::OMPD_teams, D::OMPD_distribute_parallel_do};
+ case D::OMPD_teams_distribute_parallel_do_simd:
+ return {D::OMPD_teams, D::OMPD_distribute_parallel_do_simd};
+ case D::OMPD_teams_distribute_simd:
+ return {D::OMPD_teams, D::OMPD_distribute_simd};
+ case D::OMPD_parallel_loop:
+ return {D::OMPD_parallel, D::OMPD_loop};
+ case D::OMPD_target_parallel_loop:
+ return {D::OMPD_target, D::OMPD_parallel_loop};
+ case D::OMPD_target_teams_loop:
+ return {D::OMPD_target, D::OMPD_teams_loop};
+ case D::OMPD_teams_loop:
+ return {D::OMPD_teams, D::OMPD_loop};
+ default:
+ return {dir, std::nullopt};
+ }
+}
+
//===----------------------------------------------------------------------===//
// Op body generation helper structures and functions
//===----------------------------------------------------------------------===//
@@ -1962,16 +2037,44 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
// Code generation functions for composite constructs
//===----------------------------------------------------------------------===//
-static void genCompositeDoSimd(
+static void genCompositeDistributeParallelDo(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OmpClauseList &beginClauseList,
+ const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
+}
+
+static void genCompositeDistributeParallelDoSimd(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OmpClauseList &beginClauseList,
+ const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
+}
+
+static void genCompositeDistributeSimd(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive ompDirective,
+ Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OmpClauseList &beginClauseList,
const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ TODO(loc, "Composite DISTRIBUTE SIMD");
+}
+
+static void
+genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OmpClauseList &beginClauseList,
+ const Fortran::parser::OmpClauseList *endClauseList,
+ mlir::Location loc) {
ClauseProcessor cp(converter, semaCtx, beginClauseList);
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
- clause::Order, clause::Safelen, clause::Simdlen>(loc,
- ompDirective);
+ clause::Order, clause::Safelen, clause::Simdlen>(
+ loc, llvm::omp::OMPD_do_simd);
// TODO: Add support for vectorization - add vectorization hints inside loop
// body.
// OpenMP standard does not specify the length of vector instructions.
@@ -1983,6 +2086,16 @@ static void genCompositeDoSimd(
genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList);
}
+static void
+genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OmpClauseList &beginClauseList,
+ const Fortran::parser::OmpClauseList *endClauseList,
+ mlir::Location loc) {
+ TODO(loc, "Composite TASKLOOP SIMD");
+}
+
//===----------------------------------------------------------------------===//
// OpenMPDeclarativeConstruct visitors
//===----------------------------------------------------------------------===//
@@ -2240,13 +2353,18 @@ genOMP(Fortran::lower::AbstractConverter &converter,
std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
const auto &endBlockDirective =
std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
- const auto &directive =
- std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
+ mlir::Location currentLocation =
+ converter.genLocation(beginBlockDirective.source);
+ const auto origDirective =
+ std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t).v;
const auto &beginClauseList =
std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
const auto &endClauseList =
std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t);
+ assert(llvm::omp::blockConstructSet.test(origDirective) &&
+ "Expected block construct");
+
for (const Fortran::parser::OmpClause &clause : beginClauseList.v) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (!std::get_if<Fortran::parser::OmpClause::If>(&clause.u) &&
@@ -2280,93 +2398,74 @@ genOMP(Fortran::lower::AbstractConverter &converter,
TODO(clauseLocation, "OpenMP Block construct clause");
}
- bool singleDirective = true;
- mlir::Location currentLocation = converter.genLocation(directive.source);
- switch (directive.v) {
- case llvm::omp::Directive::OMPD_master:
- genMasterOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation);
- break;
- case llvm::omp::Directive::OMPD_ordered:
- genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true,
- currentLocation, beginClauseList);
- break;
- case llvm::omp::Directive::OMPD_parallel:
- genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
- currentLocation, beginClauseList);
- break;
- case llvm::omp::Directive::OMPD_single:
- genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
- beginClauseList, endClauseList);
- break;
- case llvm::omp::Directive::OMPD_target:
- genTargetOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
+ std::optional<llvm::omp::Directive> nextDir = origDirective;
+ bool outermostLeafConstruct = true;
+ while (nextDir) {
+ llvm::omp::Directive leafDir;
+ std::tie(leafDir, nextDir) = splitCombinedDirective(*nextDir);
+ const bool genNested = !nextDir;
+ const bool outerCombined = outermostLeafConstruct && nextDir.has_value();
+ switch (leafDir) {
+ case llvm::omp::Directive::OMPD_master:
+ // 2.16 MASTER construct.
+ genMasterOp(converter, semaCtx, eval, genNested, currentLocation);
+ break;
+ case llvm::omp::Directive::OMPD_ordered:
+ // 2.17.9 ORDERED construct.
+ genOrderedRegionOp(converter, semaCtx, eval, genNested, currentLocation,
+ beginClauseList);
+ break;
+ case llvm::omp::Directive::OMPD_parallel:
+ // 2.6 PARALLEL construct.
+ genParallelOp(converter, symTable, semaCtx, eval, genNested,
+ currentLocation, beginClauseList, outerCombined);
+ break;
+ case llvm::omp::Directive::OMPD_single:
+ // 2.8.2 SINGLE construct.
+ genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
+ beginClauseList, endClauseList);
+ break;
+ case llvm::omp::Directive::OMPD_target:
+ // 2.12.5 TARGET construct.
+ genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
+ beginClauseList, outerCombined);
+ break;
+ case llvm::omp::Directive::OMPD_target_data:
+ // 2.12.2 TARGET DATA construct.
+ genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation,
+ beginClauseList);
+ break;
+ case llvm::omp::Directive::OMPD_task:
+ // 2.10.1 TASK construct.
+ genTaskOp(converter, semaCtx, eval, genNested, currentLocation,
beginClauseList);
- break;
- case llvm::omp::Directive::OMPD_target_data:
- genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true,
- currentLocation, beginClauseList);
- break;
- case llvm::omp::Directive::OMPD_task:
- genTaskOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
- beginClauseList);
- break;
- case llvm::omp::Directive::OMPD_taskgroup:
- genTaskgroupOp(converter, semaCtx, eval, /*genNested=*/true,
- currentLocation, beginClauseList);
- break;
- case llvm::omp::Directive::OMPD_teams:
- genTeamsOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
- beginClauseList);
- break;
- case llvm::omp::Directive::OMPD_workshare:
- // FIXME: Workshare is not a commonly used OpenMP construct, an
- // implementation for this feature will come later. For the codes
- // that use this construct, add a single construct for now.
- genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
- beginClauseList, endClauseList);
- break;
- default:
- singleDirective = false;
- break;
- }
-
- if (singleDirective)
- return;
-
- // Codegen for combined directives
- bool combinedDirective = false;
- if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet)
- .test(directive.v)) {
- genTargetOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
- beginClauseList, /*outerCombined=*/true);
- combinedDirective = true;
- }
- if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
- .test(directive.v)) {
- genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
- beginClauseList);
- combinedDirective = true;
- }
- if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet)
- .test(directive.v)) {
- bool outerCombined =
- directive.v != llvm::omp::Directive::OMPD_target_parallel;
- genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
- currentLocation, beginClauseList, outerCombined);
- combinedDirective = true;
- }
- if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
- .test(directive.v)) {
- genSingleOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
- beginClauseList, endClauseList);
- combinedDirective = true;
+ break;
+ case llvm::omp::Directive::OMPD_taskgroup:
+ // 2.17.6 TASKGROUP construct.
+ genTaskgroupOp(converter, semaCtx, eval, genNested, currentLocation,
+ beginClauseList);
+ break;
+ case llvm::omp::Directive::OMPD_teams:
+ // 2.7 TEAMS construct.
+ // FIXME Pass the outerCombined argument or rename it to better describe
+ // what it represents if it must always be `false` in this context.
+ genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
+ beginClauseList);
+ break;
+ case llvm::omp::Directive::OMPD_workshare:
+ // 2.8.3 WORKSHARE construct.
+ // FIXME: Workshare is not a commonly used OpenMP construct, an
+ // implementation for this feature will come later. For the codes
+ // that use this construct, add a single construct for now.
+ genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
+ beginClauseList, endClauseList);
+ break;
+ default:
+ llvm_unreachable("Unexpected block construct");
+ break;
+ }
+ outermostLeafConstruct = false;
}
- if (!combinedDirective)
- TODO(currentLocation, "Unhandled block directive (" +
- llvm::omp::getOpenMPDirectiveName(directive.v) +
- ")");
-
- genNestedEvaluations(converter, eval);
}
static void
@@ -2404,9 +2503,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t);
mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);
- const auto ompDirective =
+ const auto origDirective =
std::get<Fortran::parser::OmpLoopDirective>(beginLoopDirective.t).v;
+ assert(llvm::omp::loopConstructSet.test(origDirective) &&
+ "Expected loop construct");
+
const auto *endClauseList = [&]() {
using RetTy = const Fortran::parser::OmpClauseList *;
if (auto &endLoopDirective =
@@ -2418,57 +2520,105 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
return RetTy();
}();
- bool validDirective = false;
- if (llvm::omp::topTaskloopSet.test(ompDirective)) {
- validDirective = true;
- genTaskloopOp(converter, semaCtx, eval, currentLocation, beginClauseList);
- } else {
- // Create omp.{target, teams, distribute, parallel} nested operations
- if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet)
- .test(ompDirective)) {
- validDirective = true;
- genTargetOp(converter, semaCtx, eval, /*genNested=*/false,
- currentLocation, beginClauseList, /*outerCombined=*/true);
- }
- if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
- .test(ompDirective)) {
- validDirective = true;
- genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
- beginClauseList, /*outerCombined=*/true);
- }
- if (llvm::omp::allDistributeSet.test(ompDirective)) {
- validDirective = true;
- genDistributeOp(converter, semaCtx, eval, /*genNested=*/false,
- currentLocation, beginClauseList);
- }
- if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
- .test(ompDirective)) {
- validDirective = true;
- genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
- currentLocation, beginClauseList, /*outerCombined=*/true);
+ std::optional<llvm::omp::Directive> nextDir = origDirective;
+ while (nextDir) {
+ llvm::omp::Directive leafDir;
+ std::tie(leafDir, nextDir) = splitCombinedDirective(*nextDir);
+ if (llvm::omp::compositeConstructSet.test(leafDir)) {
+ assert(!nextDir && "Composite construct cannot be split");
+ switch (leafDir) {
+ case llvm::omp::Directive::OMPD_distribute_parallel_do:
+ // 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct.
+ genCompositeDistributeParallelDo(converter, semaCtx, eval,
+ beginClauseList, endClauseList,
+ currentLocation);
+ break;
+ case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
+ // 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct.
+ genCompositeDistributeParallelDoSimd(converter, semaCtx, eval,
+ beginClauseList, endClauseList,
+ currentLocation);
+ break;
+ case llvm::omp::Directive::OMPD_distribute_simd:
+ // 2.9.4.2 DISTRIBUTE SIMD construct.
+ genCompositeDistributeSimd(converter, semaCtx, eval, beginClauseList,
+ endClauseList, currentLocation);
+ break;
+ case llvm::omp::Directive::OMPD_do_simd:
+ // 2.9.3.2 Worksharing-Loop SIMD construct.
+ genCompositeDoSimd(converter, semaCtx, eval, beginClauseList,
+ endClauseList, currentLocation);
+ break;
+ case llvm::omp::Directive::OMPD_taskloop_simd:
+ // 2.10.3 TASKLOOP SIMD construct.
+ genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauseList,
+ endClauseList, currentLocation);
+ break;
+ default:
+ llvm_unreachable("Unexpected composite construct");
+ }
+ } else {
+ const bool genNested = !nextDir;
+ switch (leafDir) {
+ case llvm::omp::Directive::OMPD_distribute:
+ // 2.9.4.1 DISTRIBUTE construct.
+ genDistributeOp(converter, semaCtx, eval, genNested, currentLocation,
+ beginClauseList);
+ break;
+ case llvm::omp::Directive::OMPD_do:
+ // 2.9.2 Worksharing-Loop construct.
+ genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList,
+ endClauseList);
+ break;
+ case llvm::omp::Directive::OMPD_parallel:
+ // 2.6 PARALLEL construct.
+ // FIXME This is not necessarily always the outer leaf construct of a
+ // combined construct in this context (e.g. distribute parallel do).
+ // Maybe rename the argument if it represents something else or
+ // initialize it properly.
+ genParallelOp(converter, symTable, semaCtx, eval, genNested,
+ currentLocation, beginClauseList,
+ /*outerCombined=*/true);
+ break;
+ case llvm::omp::Directive::OMPD_simd:
+ // 2.9.3.1 SIMD construct.
+ genSimdLoopOp(converter, semaCtx, eval, currentLocation,
+ beginClauseList);
+ genOpenMPReduction(converter, semaCtx, beginClauseList);
+ break;
+ case llvm::omp::Directive::OMPD_target:
+ // 2.12.5 TARGET construct.
+ genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
+ beginClauseList, /*outerCombined=*/true);
+ break;
+ case llvm::omp::Directive::OMPD_taskloop:
+ // 2.10.2 TASKLOOP construct.
+ genTaskloopOp(converter, semaCtx, eval, currentLocation,
+ beginClauseList);
+ break;
+ case llvm::omp::Directive::OMPD_teams:
+ // 2.7 TEAMS construct.
+ // FIXME This is not necessarily always the outer leaf construct of a
+ // combined construct in this context (e.g. target teams distribute).
+ // Maybe rename the argument if it represents something else or
+ // initialize it properly.
+ genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
+ beginClauseList, /*outerCombined=*/true);
+ break;
+ case llvm::omp::Directive::OMPD_loop:
+ case llvm::omp::Directive::OMPD_masked:
+ case llvm::omp::Directive::OMPD_master:
+ case llvm::omp::Directive::OMPD_tile:
+ case llvm::omp::Directive::OMPD_unroll:
+ TODO(currentLocation, "Unhandled loop directive (" +
+ llvm::omp::getOpenMPDirectiveName(leafDir) +
+ ")");
+ break;
+ default:
+ llvm_unreachable("Unexpected loop construct");
+ }
}
}
- if ((llvm::omp::allDoSet | llvm::omp::allSimdSet).test(ompDirective))
- validDirective = true;
-
- if (!validDirective) {
- TODO(currentLocation, "Unhandled loop directive (" +
- llvm::omp::getOpenMPDirectiveName(ompDirective) +
- ")");
- }
-
- if (llvm::omp::allDoSimdSet.test(ompDirective)) {
- // 2.9.3.2 Workshare SIMD construct
- genCompositeDoSimd(converter, semaCtx, eval, ompDirective, beginClauseList,
- endClauseList, currentLocation);
- } else if (llvm::omp::allSimdSet.test(ompDirective)) {
- // 2.9.3.1 SIMD construct
- genSimdLoopOp(converter, semaCtx, eval, currentLocation, beginClauseList);
- genOpenMPReduction(converter, semaCtx, beginClauseList);
- } else {
- genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList,
- endClauseList);
- }
}
static void
>From f725face892cef4faf9f17d4b549541bdbcd7e08 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Fri, 29 Mar 2024 09:20:41 -0500
Subject: [PATCH 05/17] [flang][OpenMP] Move clause/object conversion to happen
early, in genOMP
This removes the last use of genObjectList2, which has now been removed.
---
flang/lib/Lower/OpenMP/ClauseProcessor.h | 5 +-
flang/lib/Lower/OpenMP/DataSharingProcessor.h | 5 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 424 +++++++++---------
flang/lib/Lower/OpenMP/Utils.cpp | 30 +-
flang/lib/Lower/OpenMP/Utils.h | 6 +-
5 files changed, 218 insertions(+), 252 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index db7a1b8335f818..f4d659b70cfee7 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -49,9 +49,8 @@ class ClauseProcessor {
public:
ClauseProcessor(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses)
- : converter(converter), semaCtx(semaCtx),
- clauses(makeClauses(clauses, semaCtx)) {}
+ const List<Clause> &clauses)
+ : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
// 'Unique' clauses: They can appear at most once in the clause list.
bool processCollapse(
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
index c11ee299c5d085..ef7b14327278e3 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -78,13 +78,12 @@ class DataSharingProcessor {
public:
DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &opClauseList,
+ const List<Clause> &clauses,
Fortran::lower::pft::Evaluation &eval,
bool useDelayedPrivatization = false,
Fortran::lower::SymMap *symTable = nullptr)
: hasLastPrivateOp(false), converter(converter),
- firOpBuilder(converter.getFirOpBuilder()),
- clauses(omp::makeClauses(opClauseList, semaCtx)), eval(eval),
+ firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval),
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable) {}
// Privatisation is split into two steps.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index edae453972d3d9..23dc25ac1ae9a1 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -17,6 +17,7 @@
#include "DataSharingProcessor.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
+#include "Utils.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
@@ -310,14 +311,15 @@ static void getDeclareTargetInfo(
} else if (const auto *clauseList{
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
spec.u)}) {
- if (clauseList->v.empty()) {
+ List<Clause> clauses = makeClauses(*clauseList, semaCtx);
+ if (clauses.empty()) {
// Case: declare target, implicit capture of function
symbolAndClause.emplace_back(
mlir::omp::DeclareTargetCaptureClause::to,
eval.getOwningProcedure()->getSubprogramSymbol());
}
- ClauseProcessor cp(converter, semaCtx, *clauseList);
+ ClauseProcessor cp(converter, semaCtx, clauses);
cp.processDeviceType(clauseOps);
cp.processEnter(symbolAndClause);
cp.processLink(symbolAndClause);
@@ -597,14 +599,11 @@ static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) {
// TODO: Generate the reduction operation during lowering instead of creating
// and removing operations since this is not a robust approach. Also, removing
// ops in the builder (instead of a rewriter) is probably not the best approach.
-static void
-genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauseList) {
+static void genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- List<Clause> clauses{makeClauses(clauseList, semaCtx)};
-
for (const Clause &clause : clauses) {
if (const auto &reductionClause =
std::get_if<clause::Reduction>(&clause.u)) {
@@ -812,7 +811,7 @@ struct OpWithBodyGenInfo {
return *this;
}
- OpWithBodyGenInfo &setClauses(const Fortran::parser::OmpClauseList *value) {
+ OpWithBodyGenInfo &setClauses(const List<Clause> *value) {
clauses = value;
return *this;
}
@@ -848,7 +847,7 @@ struct OpWithBodyGenInfo {
/// [in] is this an outer operation - prevents privatization.
bool outerCombined = false;
/// [in] list of clauses to process.
- const Fortran::parser::OmpClauseList *clauses = nullptr;
+ const List<Clause> *clauses = nullptr;
/// [in] if provided, processes the construct's data-sharing attributes.
DataSharingProcessor *dsp = nullptr;
/// [in] if provided, list of reduction symbols
@@ -1226,36 +1225,33 @@ static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) {
// Code generation functions for clauses
//===----------------------------------------------------------------------===//
-static void genCriticalDeclareClauses(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) {
+static void
+genCriticalDeclareClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ mlir::omp::CriticalClauseOps &clauseOps,
+ llvm::StringRef name) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processHint(clauseOps);
clauseOps.nameAttr =
mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name);
}
-static void genFlushClauses(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const std::optional<Fortran::parser::OmpObjectList> &objects,
- const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>
- &clauses,
- mlir::Location loc, llvm::SmallVectorImpl<mlir::Value> &operandRange) {
- if (objects)
- genObjectList2(*objects, converter, operandRange);
-
- if (clauses && clauses->size() > 0)
+static void genFlushClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const ObjectList &objects,
+ const List<Clause> &clauses, mlir::Location loc,
+ llvm::SmallVectorImpl<mlir::Value> &operandRange) {
+ genObjectList(objects, converter, operandRange);
+
+ if (clauses.size() > 0)
TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause");
}
static void
genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::OrderedRegionClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processTODO<clause::Simd>(loc, llvm::omp::Directive::OMPD_ordered);
@@ -1264,9 +1260,9 @@ genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter,
static void genParallelClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- bool processReduction, mlir::omp::ParallelClauseOps &clauseOps,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, bool processReduction,
+ mlir::omp::ParallelClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
@@ -1286,8 +1282,7 @@ static void genParallelClauses(
static void genSectionsClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
bool clausesFromBeginSections,
mlir::omp::SectionsClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
@@ -1304,9 +1299,8 @@ static void genSimdLoopClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- mlir::omp::SimdLoopClauseOps &clauseOps,
+ Fortran::lower::pft::Evaluation &eval, const List<Clause> &clauses,
+ mlir::Location loc, mlir::omp::SimdLoopClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processCollapse(loc, eval, clauseOps, iv);
@@ -1324,9 +1318,8 @@ static void genSimdLoopClauses(
static void genSingleClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &beginClauses,
- const Fortran::parser::OmpClauseList &endClauses,
- mlir::Location loc,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc,
mlir::omp::SingleClauseOps &clauseOps) {
ClauseProcessor bcp(converter, semaCtx, beginClauses);
bcp.processAllocate(clauseOps);
@@ -1340,9 +1333,8 @@ static void genSingleClauses(Fortran::lower::AbstractConverter &converter,
static void genTargetClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- bool processHostOnlyClauses, bool processReduction,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, bool processHostOnlyClauses, bool processReduction,
mlir::omp::TargetClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSyms,
llvm::SmallVectorImpl<mlir::Location> &mapSymLocs,
@@ -1368,9 +1360,8 @@ static void genTargetClauses(
static void genTargetDataClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- mlir::omp::TargetDataClauseOps &clauseOps,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, mlir::omp::TargetDataClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) {
@@ -1401,9 +1392,8 @@ static void genTargetDataClauses(
static void genTargetEnterExitUpdateDataClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- llvm::omp::Directive directive,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, llvm::omp::Directive directive,
mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processDepend(clauseOps);
@@ -1422,8 +1412,7 @@ static void genTargetEnterExitUpdateDataClauses(
static void genTaskClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
@@ -1442,8 +1431,7 @@ static void genTaskClauses(Fortran::lower::AbstractConverter &converter,
static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
@@ -1453,8 +1441,7 @@ static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter,
static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskwaitClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processTODO<clause::Depend, clause::Nowait>(
@@ -1464,8 +1451,7 @@ static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter,
static void genTeamsClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TeamsClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
@@ -1482,9 +1468,8 @@ static void genWsloopClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauses,
- const Fortran::parser::OmpClauseList *endClauses, mlir::Location loc,
+ Fortran::lower::pft::Evaluation &eval, const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc,
mlir::omp::WsloopClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
@@ -1501,8 +1486,8 @@ static void genWsloopClauses(
if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
- if (endClauses) {
- ClauseProcessor ecp(converter, semaCtx, *endClauses);
+ if (!endClauses.empty()) {
+ ClauseProcessor ecp(converter, semaCtx, endClauses);
ecp.processNowait(clauseOps);
}
@@ -1525,8 +1510,7 @@ static mlir::omp::CriticalOp
genCriticalOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
const std::optional<Fortran::parser::Name> &name) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::FlatSymbolRefAttr nameAttr;
@@ -1537,7 +1521,7 @@ genCriticalOp(Fortran::lower::AbstractConverter &converter,
auto global = mod.lookupSymbol<mlir::omp::CriticalDeclareOp>(nameStr);
if (!global) {
mlir::omp::CriticalClauseOps clauseOps;
- genCriticalDeclareClauses(converter, semaCtx, clauseList, loc, clauseOps,
+ genCriticalDeclareClauses(converter, semaCtx, clauses, loc, clauseOps,
nameStr);
mlir::OpBuilder modBuilder(mod.getBodyRegion());
@@ -1556,8 +1540,7 @@ static mlir::omp::DistributeOp
genDistributeOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
TODO(loc, "Distribute construct");
return nullptr;
}
@@ -1566,12 +1549,9 @@ static mlir::omp::FlushOp
genFlushOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const std::optional<Fortran::parser::OmpObjectList> &objectList,
- const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>
- &clauseList) {
+ const ObjectList &objects, const List<Clause> &clauses) {
llvm::SmallVector<mlir::Value> operandRange;
- genFlushClauses(converter, semaCtx, objectList, clauseList, loc,
- operandRange);
+ genFlushClauses(converter, semaCtx, objects, clauses, loc, operandRange);
return converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
converter.getCurrentLocation(), operandRange);
@@ -1591,7 +1571,7 @@ static mlir::omp::OrderedOp
genOrderedOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ const List<Clause> &clauses) {
TODO(loc, "OMPD_ordered");
return nullptr;
}
@@ -1600,10 +1580,9 @@ static mlir::omp::OrderedRegionOp
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
mlir::omp::OrderedRegionClauseOps clauseOps;
- genOrderedRegionClauses(converter, semaCtx, clauseList, loc, clauseOps);
+ genOrderedRegionClauses(converter, semaCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::OrderedRegionOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested),
@@ -1615,8 +1594,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1624,7 +1602,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
- genParallelClauses(converter, semaCtx, stmtCtx, clauseList, loc,
+ genParallelClauses(converter, semaCtx, stmtCtx, clauses, loc,
/*processReduction=*/!outerCombined, clauseOps,
reductionTypes, reductionSyms);
@@ -1637,7 +1615,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
- .setClauses(&clauseList)
+ .setClauses(&clauses)
.setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(reductionCallback);
@@ -1645,7 +1623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
bool privatize = !outerCombined;
- DataSharingProcessor dsp(converter, semaCtx, clauseList, eval,
+ DataSharingProcessor dsp(converter, semaCtx, clauses, eval,
/*useDelayedPrivatization=*/true, &symTable);
if (privatize)
@@ -1692,14 +1670,13 @@ static mlir::omp::SectionOp
genSectionOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
// Currently only private/firstprivate clause is handled, and
// all privatization is done within `omp.section` operations.
return genOpWithBody<mlir::omp::SectionOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&clauseList));
+ .setClauses(&clauses));
}
static mlir::omp::SectionsOp
@@ -1716,18 +1693,17 @@ static mlir::omp::SimdLoopOp
genSimdLoopOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
- DataSharingProcessor dsp(converter, semaCtx, clauseList, eval);
+ const List<Clause> &clauses) {
+ DataSharingProcessor dsp(converter, semaCtx, clauses, eval);
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
mlir::omp::SimdLoopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
- genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauseList, loc,
- clauseOps, iv);
+ genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauses, loc, clauseOps,
+ iv);
- auto *nestedEval =
- getCollapsedLoopEval(eval, Fortran::lower::getCollapseValue(clauseList));
+ auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(clauses));
auto ivCallback = [&](mlir::Operation *op) {
return genLoopVars(op, converter, loc, iv);
@@ -1735,7 +1711,7 @@ genSimdLoopOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::SimdLoopOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&clauseList)
+ .setClauses(&clauses)
.setDataSharingProcessor(&dsp)
.setGenRegionEntryCb(ivCallback),
clauseOps);
@@ -1745,17 +1721,16 @@ static mlir::omp::SingleOp
genSingleOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList &endClauseList) {
+ mlir::Location loc, const List<Clause> &beginClauses,
+ const List<Clause> &endClauses) {
mlir::omp::SingleClauseOps clauseOps;
- genSingleClauses(converter, semaCtx, beginClauseList, endClauseList, loc,
+ genSingleClauses(converter, semaCtx, beginClauses, endClauses, loc,
clauseOps);
return genOpWithBody<mlir::omp::SingleOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&beginClauseList),
+ .setClauses(&beginClauses),
clauseOps);
}
@@ -1763,8 +1738,7 @@ static mlir::omp::TargetOp
genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1777,7 +1751,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms;
llvm::SmallVector<mlir::Location> mapSymLocs;
llvm::SmallVector<mlir::Type> mapSymTypes;
- genTargetClauses(converter, semaCtx, stmtCtx, clauseList, loc,
+ genTargetClauses(converter, semaCtx, stmtCtx, clauses, loc,
processHostOnlyClauses, /*processReduction=*/outerCombined,
clauseOps, mapSyms, mapSymLocs, mapSymTypes);
@@ -1875,14 +1849,13 @@ static mlir::omp::TargetDataOp
genTargetDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TargetDataClauseOps clauseOps;
llvm::SmallVector<mlir::Type> useDeviceTypes;
llvm::SmallVector<mlir::Location> useDeviceLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms;
- genTargetDataClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps,
+ genTargetDataClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps,
useDeviceTypes, useDeviceLocs, useDeviceSyms);
auto targetDataOp =
@@ -1894,11 +1867,11 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
return targetDataOp;
}
-template <typename OpTy>
-static OpTy genTargetEnterExitUpdateDataOp(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+template <typename OpTy> static OpTy
+genTargetEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ mlir::Location loc,
+ const List<Clause> &clauses) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1915,8 +1888,8 @@ static OpTy genTargetEnterExitUpdateDataOp(
}
mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps;
- genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauseList,
- loc, directive, clauseOps);
+ genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauses, loc,
+ directive, clauseOps);
return firOpBuilder.create<OpTy>(loc, clauseOps);
}
@@ -1925,16 +1898,15 @@ static mlir::omp::TaskOp
genTaskOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TaskClauseOps clauseOps;
- genTaskClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps);
+ genTaskClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::TaskOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&clauseList),
+ .setClauses(&clauses),
clauseOps);
}
@@ -1942,15 +1914,14 @@ static mlir::omp::TaskgroupOp
genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
mlir::omp::TaskgroupClauseOps clauseOps;
- genTaskgroupClauses(converter, semaCtx, clauseList, loc, clauseOps);
+ genTaskgroupClauses(converter, semaCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&clauseList),
+ .setClauses(&clauses),
clauseOps);
}
@@ -1958,7 +1929,7 @@ static mlir::omp::TaskloopOp
genTaskloopOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ const List<Clause> &clauses) {
TODO(loc, "Taskloop construct");
}
@@ -1966,9 +1937,9 @@ static mlir::omp::TaskwaitOp
genTaskwaitOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ const List<Clause> &clauses) {
mlir::omp::TaskwaitClauseOps clauseOps;
- genTaskwaitClauses(converter, semaCtx, clauseList, loc, clauseOps);
+ genTaskwaitClauses(converter, semaCtx, clauses, loc, clauseOps);
return converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(loc,
clauseOps);
}
@@ -1984,17 +1955,17 @@ static mlir::omp::TeamsOp
genTeamsOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TeamsClauseOps clauseOps;
- genTeamsClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps);
+ genTeamsClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
- .setClauses(&clauseList),
+ .setClauses(&clauses),
clauseOps);
}
@@ -2002,9 +1973,8 @@ static mlir::omp::WsloopOp
genWsloopOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList) {
- DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval);
+ const List<Clause> &beginClauses, const List<Clause> &endClauses) {
+ DataSharingProcessor dsp(converter, semaCtx, beginClauses, eval);
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
@@ -2012,12 +1982,10 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Type> reductionTypes;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
- genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauseList,
- endClauseList, loc, clauseOps, iv, reductionTypes,
- reductionSyms);
+ genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauses, endClauses,
+ loc, clauseOps, iv, reductionTypes, reductionSyms);
- auto *nestedEval = getCollapsedLoopEval(
- eval, Fortran::lower::getCollapseValue(beginClauseList));
+ auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(beginClauses));
auto ivCallback = [&](mlir::Operation *op) {
return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms,
@@ -2026,7 +1994,7 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::WsloopOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&beginClauseList)
+ .setClauses(&beginClauses)
.setDataSharingProcessor(&dsp)
.setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(ivCallback),
@@ -2041,8 +2009,8 @@ static void genCompositeDistributeParallelDo(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
}
@@ -2050,8 +2018,8 @@ static void genCompositeDistributeParallelDoSimd(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
}
@@ -2059,8 +2027,8 @@ static void genCompositeDistributeSimd(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE SIMD");
}
@@ -2068,10 +2036,10 @@ static void
genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses,
mlir::Location loc) {
- ClauseProcessor cp(converter, semaCtx, beginClauseList);
+ ClauseProcessor cp(converter, semaCtx, beginClauses);
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
clause::Order, clause::Safelen, clause::Simdlen>(
loc, llvm::omp::OMPD_do_simd);
@@ -2083,15 +2051,15 @@ genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
// When support for vectorization is enabled, then we need to add handling of
// if clause. Currently if clause can be skipped because we always assume
// SIMD length = 1.
- genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList);
+ genWsloopOp(converter, semaCtx, eval, loc, beginClauses, endClauses);
}
static void
genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses,
mlir::Location loc) {
TODO(loc, "Composite TASKLOOP SIMD");
}
@@ -2201,8 +2169,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
const auto &directive =
std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
simpleStandaloneConstruct.t);
- const auto &clauseList =
- std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t);
+ List<Clause> clauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t),
+ semaCtx);
mlir::Location currentLocation = converter.genLocation(directive.source);
switch (directive.v) {
@@ -2212,29 +2181,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
genBarrierOp(converter, semaCtx, eval, currentLocation);
break;
case llvm::omp::Directive::OMPD_taskwait:
- genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauseList);
+ genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_taskyield:
genTaskyieldOp(converter, semaCtx, eval, currentLocation);
break;
case llvm::omp::Directive::OMPD_target_data:
genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true,
- currentLocation, clauseList);
+ currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_target_enter_data:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetEnterDataOp>(
- converter, semaCtx, currentLocation, clauseList);
+ converter, semaCtx, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_target_exit_data:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetExitDataOp>(
- converter, semaCtx, currentLocation, clauseList);
+ converter, semaCtx, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_target_update:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetUpdateOp>(
- converter, semaCtx, currentLocation, clauseList);
+ converter, semaCtx, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_ordered:
- genOrderedOp(converter, semaCtx, eval, currentLocation, clauseList);
+ genOrderedOp(converter, semaCtx, eval, currentLocation, clauses);
break;
}
}
@@ -2251,8 +2220,14 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const auto &clauseList =
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
flushConstruct.t);
+ ObjectList objects =
+ objectList ? makeObjects(*objectList, semaCtx) : ObjectList{};
+ List<Clause> clauses =
+ clauseList ? makeList(*clauseList,
+ [&](auto &&s) { return makeClause(s.v, semaCtx); })
+ : List<Clause>{};
mlir::Location currentLocation = converter.genLocation(verbatim.source);
- genFlushOp(converter, semaCtx, eval, currentLocation, objectList, clauseList);
+ genFlushOp(converter, semaCtx, eval, currentLocation, objects, clauses);
}
static void
@@ -2357,44 +2332,44 @@ genOMP(Fortran::lower::AbstractConverter &converter,
converter.genLocation(beginBlockDirective.source);
const auto origDirective =
std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t).v;
- const auto &beginClauseList =
- std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
- const auto &endClauseList =
- std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t);
+ List<Clause> beginClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t), semaCtx);
+ List<Clause> endClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t), semaCtx);
assert(llvm::omp::blockConstructSet.test(origDirective) &&
"Expected block construct");
- for (const Fortran::parser::OmpClause &clause : beginClauseList.v) {
+ for (const Clause &clause : beginClauses) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
- if (!std::get_if<Fortran::parser::OmpClause::If>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::NumThreads>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Allocate>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Default>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Final>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Priority>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Firstprivate>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Simd>(&clause.u)) {
+ if (!std::get_if<clause::If>(&clause.u) &&
+ !std::get_if<clause::NumThreads>(&clause.u) &&
+ !std::get_if<clause::ProcBind>(&clause.u) &&
+ !std::get_if<clause::Allocate>(&clause.u) &&
+ !std::get_if<clause::Default>(&clause.u) &&
+ !std::get_if<clause::Final>(&clause.u) &&
+ !std::get_if<clause::Priority>(&clause.u) &&
+ !std::get_if<clause::Reduction>(&clause.u) &&
+ !std::get_if<clause::Depend>(&clause.u) &&
+ !std::get_if<clause::Private>(&clause.u) &&
+ !std::get_if<clause::Firstprivate>(&clause.u) &&
+ !std::get_if<clause::Copyin>(&clause.u) &&
+ !std::get_if<clause::Shared>(&clause.u) &&
+ !std::get_if<clause::Threads>(&clause.u) &&
+ !std::get_if<clause::Map>(&clause.u) &&
+ !std::get_if<clause::UseDevicePtr>(&clause.u) &&
+ !std::get_if<clause::UseDeviceAddr>(&clause.u) &&
+ !std::get_if<clause::ThreadLimit>(&clause.u) &&
+ !std::get_if<clause::NumTeams>(&clause.u) &&
+ !std::get_if<clause::Simd>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
}
}
- for (const auto &clause : endClauseList.v) {
+ for (const Clause &clause : endClauses) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
- if (!std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Copyprivate>(&clause.u))
+ if (!std::get_if<clause::Nowait>(&clause.u) &&
+ !std::get_if<clause::Copyprivate>(&clause.u))
TODO(clauseLocation, "OpenMP Block construct clause");
}
@@ -2413,44 +2388,44 @@ genOMP(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_ordered:
// 2.17.9 ORDERED construct.
genOrderedRegionOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_parallel:
// 2.6 PARALLEL construct.
genParallelOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, beginClauseList, outerCombined);
+ currentLocation, beginClauses, outerCombined);
break;
case llvm::omp::Directive::OMPD_single:
// 2.8.2 SINGLE construct.
genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, endClauseList);
+ beginClauses, endClauses);
break;
case llvm::omp::Directive::OMPD_target:
// 2.12.5 TARGET construct.
genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, outerCombined);
+ beginClauses, outerCombined);
break;
case llvm::omp::Directive::OMPD_target_data:
// 2.12.2 TARGET DATA construct.
genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_task:
// 2.10.1 TASK construct.
genTaskOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_taskgroup:
// 2.17.6 TASKGROUP construct.
genTaskgroupOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_teams:
// 2.7 TEAMS construct.
// FIXME Pass the outerCombined argument or rename it to better describe
// what it represents if it must always be `false` in this context.
genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_workshare:
// 2.8.3 WORKSHARE construct.
@@ -2458,7 +2433,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
// implementation for this feature will come later. For the codes
// that use this construct, add a single construct for now.
genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, endClauseList);
+ beginClauses, endClauses);
break;
default:
llvm_unreachable("Unexpected block construct");
@@ -2476,11 +2451,12 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
const auto &cd =
std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
- const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
+ List<Clause> clauses =
+ makeClauses(std::get<Fortran::parser::OmpClauseList>(cd.t), semaCtx);
const auto &name = std::get<std::optional<Fortran::parser::Name>>(cd.t);
mlir::Location currentLocation = converter.getCurrentLocation();
genCriticalOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
- clauseList, name);
+ clauses, name);
}
static void
@@ -2499,8 +2475,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
const auto &beginLoopDirective =
std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t);
- const auto &beginClauseList =
- std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t);
+ List<Clause> beginClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);
const auto origDirective =
@@ -2509,15 +2485,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
assert(llvm::omp::loopConstructSet.test(origDirective) &&
"Expected loop construct");
- const auto *endClauseList = [&]() {
- using RetTy = const Fortran::parser::OmpClauseList *;
+ List<Clause> endClauses = [&]() {
if (auto &endLoopDirective =
std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>(
loopConstruct.t)) {
- return RetTy(
- &std::get<Fortran::parser::OmpClauseList>((*endLoopDirective).t));
+ return makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(endLoopDirective->t),
+ semaCtx);
}
- return RetTy();
+ return List<Clause>{};
}();
std::optional<llvm::omp::Directive> nextDir = origDirective;
@@ -2530,29 +2506,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_distribute_parallel_do:
// 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct.
genCompositeDistributeParallelDo(converter, semaCtx, eval,
- beginClauseList, endClauseList,
+ beginClauses, endClauses,
currentLocation);
break;
case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
// 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct.
genCompositeDistributeParallelDoSimd(converter, semaCtx, eval,
- beginClauseList, endClauseList,
+ beginClauses, endClauses,
currentLocation);
break;
case llvm::omp::Directive::OMPD_distribute_simd:
// 2.9.4.2 DISTRIBUTE SIMD construct.
- genCompositeDistributeSimd(converter, semaCtx, eval, beginClauseList,
- endClauseList, currentLocation);
+ genCompositeDistributeSimd(converter, semaCtx, eval, beginClauses,
+ endClauses, currentLocation);
break;
case llvm::omp::Directive::OMPD_do_simd:
// 2.9.3.2 Worksharing-Loop SIMD construct.
- genCompositeDoSimd(converter, semaCtx, eval, beginClauseList,
- endClauseList, currentLocation);
+ genCompositeDoSimd(converter, semaCtx, eval, beginClauses,
+ endClauses, currentLocation);
break;
case llvm::omp::Directive::OMPD_taskloop_simd:
// 2.10.3 TASKLOOP SIMD construct.
- genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauseList,
- endClauseList, currentLocation);
+ genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauses,
+ endClauses, currentLocation);
break;
default:
llvm_unreachable("Unexpected composite construct");
@@ -2563,12 +2539,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_distribute:
// 2.9.4.1 DISTRIBUTE construct.
genDistributeOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_do:
// 2.9.2 Worksharing-Loop construct.
- genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList,
- endClauseList);
+ genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauses,
+ endClauses);
break;
case llvm::omp::Directive::OMPD_parallel:
// 2.6 PARALLEL construct.
@@ -2577,24 +2553,24 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
// Maybe rename the argument if it represents something else or
// initialize it properly.
genParallelOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, beginClauseList,
+ currentLocation, beginClauses,
/*outerCombined=*/true);
break;
case llvm::omp::Directive::OMPD_simd:
// 2.9.3.1 SIMD construct.
genSimdLoopOp(converter, semaCtx, eval, currentLocation,
- beginClauseList);
- genOpenMPReduction(converter, semaCtx, beginClauseList);
+ beginClauses);
+ genOpenMPReduction(converter, semaCtx, beginClauses);
break;
case llvm::omp::Directive::OMPD_target:
// 2.12.5 TARGET construct.
genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, /*outerCombined=*/true);
+ beginClauses, /*outerCombined=*/true);
break;
case llvm::omp::Directive::OMPD_taskloop:
// 2.10.2 TASKLOOP construct.
genTaskloopOp(converter, semaCtx, eval, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_teams:
// 2.7 TEAMS construct.
@@ -2603,7 +2579,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
// Maybe rename the argument if it represents something else or
// initialize it properly.
genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, /*outerCombined=*/true);
+ beginClauses, /*outerCombined=*/true);
break;
case llvm::omp::Directive::OMPD_loop:
case llvm::omp::Directive::OMPD_masked:
@@ -2639,14 +2615,15 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
const auto &beginSectionsDirective =
std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t);
- const auto &beginClauseList =
- std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t);
+ List<Clause> beginClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t),
+ semaCtx);
// Process clauses before optional omp.parallel, so that new variables are
// allocated outside of the parallel region
mlir::Location currentLocation = converter.getCurrentLocation();
mlir::omp::SectionsClauseOps clauseOps;
- genSectionsClauses(converter, semaCtx, beginClauseList, currentLocation,
+ genSectionsClauses(converter, semaCtx, beginClauses, currentLocation,
/*clausesFromBeginSections=*/true, clauseOps);
// Parallel wrapper of PARALLEL SECTIONS construct
@@ -2655,14 +2632,15 @@ genOMP(Fortran::lower::AbstractConverter &converter,
.v;
if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
genParallelOp(converter, symTable, semaCtx, eval,
- /*genNested=*/false, currentLocation, beginClauseList,
+ /*genNested=*/false, currentLocation, beginClauses,
/*outerCombined=*/true);
} else {
const auto &endSectionsDirective =
std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
- const auto &endClauseList =
- std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
- genSectionsClauses(converter, semaCtx, endClauseList, currentLocation,
+ List<Clause> endClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t),
+ semaCtx);
+ genSectionsClauses(converter, semaCtx, endClauses, currentLocation,
/*clausesFromBeginSections=*/false, clauseOps);
}
@@ -2678,7 +2656,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) {
symTable.pushScope();
genSectionOp(converter, semaCtx, neval, /*genNested=*/true, currentLocation,
- beginClauseList);
+ beginClauses);
symTable.popScope();
firOpBuilder.restoreInsertionPoint(ip);
}
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index b9c0660aa4da8e..da3f2be73e5095 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -36,6 +36,17 @@ namespace Fortran {
namespace lower {
namespace omp {
+int64_t getCollapseValue(const List<Clause> &clauses) {
+ auto iter = llvm::find_if(clauses, [](const Clause &clause) {
+ return clause.id == llvm::omp::Clause::OMPC_collapse;
+ });
+ if (iter != clauses.end()) {
+ const auto &collapse = std::get<clause::Collapse>(iter->u);
+ return evaluate::ToInt64(collapse.v).value();
+ }
+ return 1;
+}
+
void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands) {
@@ -52,25 +63,6 @@ void genObjectList(const ObjectList &objects,
}
}
-void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- llvm::SmallVectorImpl<mlir::Value> &operands) {
- auto addOperands = [&](Fortran::lower::SymbolRef sym) {
- const mlir::Value variable = converter.getSymbolAddress(sym);
- if (variable) {
- operands.push_back(variable);
- } else if (const auto *details =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
- operands.push_back(converter.getSymbolAddress(details->symbol()));
- converter.copySymbolBinding(details->symbol(), sym);
- }
- };
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
- addOperands(*sym);
- }
-}
-
mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
std::size_t loopVarTypeSize) {
// OpenMP runtime requires 32-bit or 64-bit loop variables.
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 4074bf73987d5b..b3a9f7f30c98bd 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -58,6 +58,8 @@ void gatherFuncAndVarSyms(
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);
+int64_t getCollapseValue(const List<Clause> &clauses);
+
Fortran::semantics::Symbol *
getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject);
@@ -65,10 +67,6 @@ void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands);
-void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- llvm::SmallVectorImpl<mlir::Value> &operands);
-
} // namespace omp
} // namespace lower
} // namespace Fortran
>From 291dc48d5e0b7e0ee39681a1276bd1d63f456b01 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Mon, 1 Apr 2024 10:07:45 -0500
Subject: [PATCH 06/17] [Frontend][OpenMP] Refactor getLeafConstructs, add
getCompoundConstruct
Emit a special leaf construct table in DirectiveEmitter.cpp, which will
allow both decomposition of a construct into leafs, and composition of
constituent constructs into a single compound construct (if possible).
---
llvm/include/llvm/Frontend/OpenMP/OMP.h | 7 +
llvm/lib/Frontend/OpenMP/OMP.cpp | 64 +++++-
llvm/test/TableGen/directive1.td | 19 +-
llvm/test/TableGen/directive2.td | 19 +-
llvm/unittests/Frontend/CMakeLists.txt | 1 +
llvm/unittests/Frontend/OpenMPComposeTest.cpp | 41 ++++
llvm/utils/TableGen/DirectiveEmitter.cpp | 194 +++++++++++-------
7 files changed, 258 insertions(+), 87 deletions(-)
create mode 100644 llvm/unittests/Frontend/OpenMPComposeTest.cpp
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h
index a85cd9d344c6d7..4ed47f15dfe59e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h
@@ -15,4 +15,11 @@
#include "llvm/Frontend/OpenMP/OMP.h.inc"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace llvm::omp {
+ArrayRef<Directive> getLeafConstructs(Directive D);
+Directive getCompoundConstruct(ArrayRef<Directive> Parts);
+} // namespace llvm::omp
+
#endif // LLVM_FRONTEND_OPENMP_OMP_H
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index 4f2f95392648b3..dd99d3d074fd1e 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -8,12 +8,74 @@
#include "llvm/Frontend/OpenMP/OMP.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"
+#include <algorithm>
+#include <iterator>
+#include <type_traits>
+
using namespace llvm;
-using namespace omp;
+using namespace llvm::omp;
#define GEN_DIRECTIVES_IMPL
#include "llvm/Frontend/OpenMP/OMP.inc"
+
+namespace llvm::omp {
+ArrayRef<Directive> getLeafConstructs(Directive D) {
+ auto Idx = static_cast<int>(D);
+ if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize))
+ return {};
+ const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
+ return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
+}
+
+Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
+ if (Parts.empty())
+ return OMPD_unknown;
+
+ // Parts don't have to be leafs, so expand them into leafs first.
+ // Store the expanded leafs in the same format as rows in the leaf
+ // table (generated by tablegen).
+ SmallVector<Directive> RawLeafs(2);
+ for (Directive P : Parts) {
+ ArrayRef<Directive> Ls = getLeafConstructs(P);
+ if (!Ls.empty())
+ RawLeafs.append(Ls.begin(), Ls.end());
+ else
+ RawLeafs.push_back(P);
+ }
+
+ auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
+ if (GivenLeafs.size() == 1)
+ return GivenLeafs.front();
+ RawLeafs[1] = static_cast<Directive>(GivenLeafs.size());
+
+ auto Iter = llvm::lower_bound(
+ LeafConstructTable,
+ static_cast<std::decay_t<decltype(*LeafConstructTable)>>(RawLeafs.data()),
+ [](const auto *RowA, const auto *RowB) {
+ const auto *BeginA = &RowA[2];
+ const auto *EndA = BeginA + static_cast<int>(RowA[1]);
+ const auto *BeginB = &RowB[2];
+ const auto *EndB = BeginB + static_cast<int>(RowB[1]);
+ if (BeginA == EndA && BeginB == EndB)
+ return static_cast<int>(RowA[0]) < static_cast<int>(RowB[0]);
+ return std::lexicographical_compare(BeginA, EndA, BeginB, EndB);
+ });
+
+ if (Iter == std::end(LeafConstructTable))
+ return OMPD_unknown;
+
+ // Verify that we got a match.
+ Directive Found = (*Iter)[0];
+ ArrayRef<Directive> FoundLeafs = getLeafConstructs(Found);
+ if (FoundLeafs == GivenLeafs)
+ return Found;
+ return OMPD_unknown;
+}
+} // namespace llvm::omp
diff --git a/llvm/test/TableGen/directive1.td b/llvm/test/TableGen/directive1.td
index 3184f625ead928..e6150210e7e9a4 100644
--- a/llvm/test/TableGen/directive1.td
+++ b/llvm/test/TableGen/directive1.td
@@ -52,6 +52,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-EMPTY:
// CHECK-NEXT: #include "llvm/ADT/ArrayRef.h"
// CHECK-NEXT: #include "llvm/ADT/BitmaskEnum.h"
+// CHECK-NEXT: #include <cstddef>
// CHECK-EMPTY:
// CHECK-NEXT: namespace llvm {
// CHECK-NEXT: class StringRef;
@@ -112,7 +113,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version.
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
// CHECK-EMPTY:
-// CHECK-NEXT: llvm::ArrayRef<Directive> getLeafConstructs(Directive D);
+// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; }
// CHECK-NEXT: Association getDirectiveAssociation(Directive D);
// CHECK-NEXT: AKind getAKind(StringRef);
// CHECK-NEXT: llvm::StringRef getTdlAKindName(AKind);
@@ -359,13 +360,6 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind");
// IMPL-NEXT: }
// IMPL-EMPTY:
-// IMPL-NEXT: llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) {
-// IMPL-NEXT: switch (Dir) {
-// IMPL-NEXT: default:
-// IMPL-NEXT: return ArrayRef<llvm::tdl::Directive>{};
-// IMPL-NEXT: } // switch (Dir)
-// IMPL-NEXT: }
-// IMPL-EMPTY:
// IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) {
// IMPL-NEXT: switch (Dir) {
// IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira:
@@ -374,4 +368,13 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Unexpected directive");
// IMPL-NEXT: }
// IMPL-EMPTY:
+// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int));
+// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = {
+// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0),
+// IMPL-NEXT: };
+// IMPL-EMPTY:
+// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = {
+// IMPL-NEXT: 0,
+// IMPL-NEXT: };
+// IMPL-EMPTY:
// IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL
diff --git a/llvm/test/TableGen/directive2.td b/llvm/test/TableGen/directive2.td
index d6fa4835c8dfdc..1750022e1f94ea 100644
--- a/llvm/test/TableGen/directive2.td
+++ b/llvm/test/TableGen/directive2.td
@@ -45,6 +45,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: #define LLVM_Tdl_INC
// CHECK-EMPTY:
// CHECK-NEXT: #include "llvm/ADT/ArrayRef.h"
+// CHECK-NEXT: #include <cstddef>
// CHECK-EMPTY:
// CHECK-NEXT: namespace llvm {
// CHECK-NEXT: class StringRef;
@@ -88,7 +89,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version.
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
// CHECK-EMPTY:
-// CHECK-NEXT: llvm::ArrayRef<Directive> getLeafConstructs(Directive D);
+// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; }
// CHECK-NEXT: Association getDirectiveAssociation(Directive D);
// CHECK-NEXT: } // namespace tdl
// CHECK-NEXT: } // namespace llvm
@@ -290,13 +291,6 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind");
// IMPL-NEXT: }
// IMPL-EMPTY:
-// IMPL-NEXT: llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) {
-// IMPL-NEXT: switch (Dir) {
-// IMPL-NEXT: default:
-// IMPL-NEXT: return ArrayRef<llvm::tdl::Directive>{};
-// IMPL-NEXT: } // switch (Dir)
-// IMPL-NEXT: }
-// IMPL-EMPTY:
// IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) {
// IMPL-NEXT: switch (Dir) {
// IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira:
@@ -305,4 +299,13 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Unexpected directive");
// IMPL-NEXT: }
// IMPL-EMPTY:
+// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int));
+// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = {
+// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0),
+// IMPL-NEXT: };
+// IMPL-EMPTY:
+// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = {
+// IMPL-NEXT: 0,
+// IMPL-NEXT: };
+// IMPL-EMPTY:
// IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL
diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt
index c6f60142d6276a..ddb6a16cbb984e 100644
--- a/llvm/unittests/Frontend/CMakeLists.txt
+++ b/llvm/unittests/Frontend/CMakeLists.txt
@@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests
OpenMPContextTest.cpp
OpenMPIRBuilderTest.cpp
OpenMPParsingTest.cpp
+ OpenMPComposeTest.cpp
DEPENDS
acc_gen
diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
new file mode 100644
index 00000000000000..29b1be4eb3432c
--- /dev/null
+++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
@@ -0,0 +1,41 @@
+//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace llvm::omp;
+
+TEST(Composition, GetLeafConstructs) {
+ ArrayRef<Directive> L1 = getLeafConstructs(OMPD_loop);
+ ASSERT_EQ(L1, (ArrayRef<Directive>{}));
+ ArrayRef<Directive> L2 = getLeafConstructs(OMPD_parallel_for);
+ ASSERT_EQ(L2, (ArrayRef<Directive>{OMPD_parallel, OMPD_for}));
+ ArrayRef<Directive> L3 = getLeafConstructs(OMPD_parallel_for_simd);
+ ASSERT_EQ(L3, (ArrayRef<Directive>{OMPD_parallel, OMPD_for, OMPD_simd}));
+}
+
+TEST(Composition, GetCompoundConstruct) {
+ Directive C1 =
+ getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute});
+ ASSERT_EQ(C1, OMPD_target_teams_distribute);
+ Directive C2 = getCompoundConstruct({OMPD_target});
+ ASSERT_EQ(C2, OMPD_target);
+ Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked});
+ ASSERT_EQ(C3, OMPD_unknown);
+ Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
+ ASSERT_EQ(C4, OMPD_target_teams_distribute);
+ Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
+ ASSERT_EQ(C5, OMPD_target_teams_distribute);
+ Directive C6 = getCompoundConstruct({});
+ ASSERT_EQ(C6, OMPD_unknown);
+ Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
+ ASSERT_EQ(C7, OMPD_parallel_for_simd);
+}
diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp
index e0edf1720f8ac5..2d2b7748491897 100644
--- a/llvm/utils/TableGen/DirectiveEmitter.cpp
+++ b/llvm/utils/TableGen/DirectiveEmitter.cpp
@@ -20,6 +20,9 @@
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
+#include <numeric>
+#include <vector>
+
using namespace llvm;
namespace {
@@ -39,7 +42,8 @@ class IfDefScope {
};
} // namespace
-// Generate enum class
+// Generate enum class. Entries are emitted in the order in which they appear
+// in the `Records` vector.
static void GenerateEnumClass(const std::vector<Record *> &Records,
raw_ostream &OS, StringRef Enum, StringRef Prefix,
const DirectiveLanguage &DirLang,
@@ -175,6 +179,16 @@ bool DirectiveLanguage::HasValidityErrors() const {
return HasDuplicateClausesInDirectives(getDirectives());
}
+// Count the maximum number of leaf constituents per construct.
+static size_t GetMaxLeafCount(const DirectiveLanguage &DirLang) {
+ size_t MaxCount = 0;
+ for (Record *R : DirLang.getDirectives()) {
+ size_t Count = Directive{R}.getLeafConstructs().size();
+ MaxCount = std::max(MaxCount, Count);
+ }
+ return MaxCount;
+}
+
// Generate the declaration section for the enumeration in the directive
// language
static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
@@ -189,6 +203,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
if (DirLang.hasEnableBitmaskEnumInNamespace())
OS << "#include \"llvm/ADT/BitmaskEnum.h\"\n";
+ OS << "#include <cstddef>\n"; // for size_t
OS << "\n";
OS << "namespace llvm {\n";
OS << "class StringRef;\n";
@@ -244,7 +259,8 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
OS << "bool isAllowedClauseForDirective(Directive D, "
<< "Clause C, unsigned Version);\n";
OS << "\n";
- OS << "llvm::ArrayRef<Directive> getLeafConstructs(Directive D);\n";
+ OS << "constexpr std::size_t getMaxLeafCount() { return "
+ << GetMaxLeafCount(DirLang) << "; }\n";
OS << "Association getDirectiveAssociation(Directive D);\n";
if (EnumHelperFuncs.length() > 0) {
OS << EnumHelperFuncs;
@@ -396,6 +412,19 @@ GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
}
}
+static std::string GetDirectiveName(const DirectiveLanguage &DirLang,
+ const Record *Rec) {
+ Directive Dir{Rec};
+ return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::" +
+ DirLang.getDirectivePrefix() + Dir.getFormattedName())
+ .str();
+}
+
+static std::string GetDirectiveType(const DirectiveLanguage &DirLang) {
+ return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::Directive")
+ .str();
+}
+
// Generate the isAllowedClauseForDirective function implementation.
static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
raw_ostream &OS) {
@@ -450,77 +479,102 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
OS << "}\n"; // End of function isAllowedClauseForDirective
}
-// Generate the getLeafConstructs function implementation.
-static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang,
- raw_ostream &OS) {
- auto getQualifiedName = [&](StringRef Formatted) -> std::string {
- return (llvm::Twine("llvm::") + DirLang.getCppNamespace() +
- "::Directive::" + DirLang.getDirectivePrefix() + Formatted)
- .str();
- };
-
- // For each list of leaves, generate a static local object, then
- // return a reference to that object for a given directive, e.g.
+static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS,
+ StringRef TableName) {
+ // The leaf constructs are emitted in a form of a 2D table, where each
+ // row corresponds to a directive (and there is a row for each directive).
//
- // static ListTy leafConstructs_A_B = { A, B };
- // static ListTy leafConstructs_C_D_E = { C, D, E };
- // switch (Dir) {
- // case A_B:
- // return leafConstructs_A_B;
- // case C_D_E:
- // return leafConstructs_C_D_E;
- // }
-
- // Map from a record that defines a directive to the name of the
- // local object with the list of its leaves.
- DenseMap<Record *, std::string> ListNames;
-
- std::string DirectiveTypeName =
- std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive";
-
- OS << '\n';
-
- // ArrayRef<...> llvm::<ns>::GetLeafConstructs(llvm::<ns>::Directive Dir)
- OS << "llvm::ArrayRef<" << DirectiveTypeName
- << "> llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs("
- << DirectiveTypeName << " Dir) ";
- OS << "{\n";
-
- // Generate the locals.
- for (Record *R : DirLang.getDirectives()) {
- Directive Dir{R};
+ // Each row consists of
+ // - the id of the directive itself,
+ // - number of leaf constructs that will follow (0 for leafs),
+ // - ids of the leaf constructs (none if the directive is itself a leaf).
+ // The total number of these entries is at most MaxLeafCount+2. If this
+ // number is less than that, it is padded to occupy exactly MaxLeafCount+2
+ // entries in memory.
+ //
+ // The rows are stored in the table in the lexicographical order. This
+ // is intended to enable binary search when mapping a sequence of leafs
+ // back to the compound directive.
+ // The consequence of that is that in order to find a row corresponding
+ // to the given directive, we'd need to scan the first element of each
+ // row. To avoid this, an auxiliary ordering table is created, such that
+ // row for Dir_A = table[auxiliary[Dir_A]].
+
+ std::vector<Record *> Directives = DirLang.getDirectives();
+ DenseMap<Record *, size_t> DirId; // Record * -> llvm::omp::Directive
+
+ for (auto [Idx, Rec] : llvm::enumerate(Directives))
+ DirId.insert(std::make_pair(Rec, Idx));
+
+ using LeafList = std::vector<int>;
+ int MaxLeafCount = GetMaxLeafCount(DirLang);
+
+ // The initial leaf table, rows order is same as directive order.
+ std::vector<LeafList> LeafTable(Directives.size());
+ for (auto [Idx, Rec] : llvm::enumerate(Directives)) {
+ Directive Dir{Rec};
+ std::vector<Record *> Leaves = Dir.getLeafConstructs();
+
+ auto &List = LeafTable[Idx];
+ List.resize(MaxLeafCount + 2);
+ List[0] = Idx; // The id of the directive itself.
+ List[1] = Leaves.size(); // The number of leaves to follow.
+
+ for (int I = 0; I != MaxLeafCount; ++I)
+ List[I + 2] =
+ static_cast<size_t>(I) < Leaves.size() ? DirId.at(Leaves[I]) : -1;
+ }
- std::vector<Record *> LeafConstructs = Dir.getLeafConstructs();
- if (LeafConstructs.empty())
- continue;
+ // Avoid sorting the vector<vector> array, instead sort an index array.
+ // It will also be useful later to create the auxiliary indexing array.
+ std::vector<int> Ordering(Directives.size());
+ std::iota(Ordering.begin(), Ordering.end(), 0);
+
+ llvm::sort(Ordering, [&](int A, int B) {
+ auto &LeavesA = LeafTable[A];
+ auto &LeavesB = LeafTable[B];
+ if (LeavesA[1] == 0 && LeavesB[1] == 0)
+ return LeavesA[0] < LeavesB[0];
+ return std::lexicographical_compare(&LeavesA[2], &LeavesA[2] + LeavesA[1],
+ &LeavesB[2], &LeavesB[2] + LeavesB[1]);
+ });
- std::string ListName = "leafConstructs_" + Dir.getFormattedName();
- OS << " static const " << DirectiveTypeName << ' ' << ListName
- << "[] = {\n";
- for (Record *L : LeafConstructs) {
- Directive LeafDir{L};
- OS << " " << getQualifiedName(LeafDir.getFormattedName()) << ",\n";
+ // Emit the table
+
+ // The directives are emitted into a scoped enum, for which the underlying
+ // type is `int` (by default). The code above uses `int` to store directive
+ // ids, so make sure that we catch it when something changes in the
+ // underlying type.
+ std::string DirectiveType = GetDirectiveType(DirLang);
+ OS << "\nstatic_assert(sizeof(" << DirectiveType << ") == sizeof(int));\n";
+
+ OS << "[[maybe_unused]] static const " << DirectiveType << ' ' << TableName
+ << "[][" << MaxLeafCount + 2 << "] = {\n";
+ for (size_t I = 0, E = Directives.size(); I != E; ++I) {
+ auto &Leaves = LeafTable[Ordering[I]];
+ OS << " " << GetDirectiveName(DirLang, Directives[Leaves[0]]);
+ OS << ", static_cast<" << DirectiveType << ">(" << Leaves[1] << "),";
+ for (size_t I = 2, E = Leaves.size(); I != E; ++I) {
+ int Idx = Leaves[I];
+ if (Idx >= 0)
+ OS << ' ' << GetDirectiveName(DirLang, Directives[Leaves[I]]) << ',';
+ else
+ OS << " static_cast<" << DirectiveType << ">(-1),";
}
- OS << " };\n";
- ListNames.insert(std::make_pair(R, std::move(ListName)));
- }
-
- if (!ListNames.empty())
OS << '\n';
- OS << " switch (Dir) {\n";
- for (Record *R : DirLang.getDirectives()) {
- auto F = ListNames.find(R);
- if (F == ListNames.end())
- continue;
-
- Directive Dir{R};
- OS << " case " << getQualifiedName(Dir.getFormattedName()) << ":\n";
- OS << " return " << F->second << ";\n";
}
- OS << " default:\n";
- OS << " return ArrayRef<" << DirectiveTypeName << ">{};\n";
- OS << " } // switch (Dir)\n";
- OS << "}\n";
+ OS << "};\n\n";
+
+ // Emit the auxiliary index table: it's the inverse of the `Ordering`
+ // table above.
+ OS << "[[maybe_unused]] static const int " << TableName << "Ordering[] = {\n";
+ OS << " ";
+ std::vector<int> Reverse(Ordering.size());
+ for (int I = 0, E = Ordering.size(); I != E; ++I)
+ Reverse[Ordering[I]] = I;
+ for (int Idx : Reverse)
+ OS << ' ' << Idx << ',';
+ OS << "\n};\n";
}
static void GenerateGetDirectiveAssociation(const DirectiveLanguage &DirLang,
@@ -1105,11 +1159,11 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang,
// isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
GenerateIsAllowedClause(DirLang, OS);
- // getLeafConstructs(Directive D)
- GenerateGetLeafConstructs(DirLang, OS);
-
// getDirectiveAssociation(Directive D)
GenerateGetDirectiveAssociation(DirLang, OS);
+
+ // Leaf table for getLeafConstructs, etc.
+ EmitLeafTable(DirLang, OS, "LeafConstructTable");
}
// Generate the implemenation section for the enumeration in the directive
>From a889f3074fc8c4ae5c6d9480308be0501217b9ff Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Mon, 11 Mar 2024 12:55:38 -0500
Subject: [PATCH 07/17] [Frontend][OpenMP] Add functions for checking construct
type
Implement helper functions to identify leaf, composite, and combined
constructs.
---
llvm/include/llvm/Frontend/OpenMP/OMP.h | 4 +++
llvm/lib/Frontend/OpenMP/OMP.cpp | 25 ++++++++++++++++++
llvm/unittests/Frontend/OpenMPComposeTest.cpp | 26 +++++++++++++++++++
3 files changed, 55 insertions(+)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h
index 4ed47f15dfe59e..ec8ae68f1c2ca0 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h
@@ -20,6 +20,10 @@
namespace llvm::omp {
ArrayRef<Directive> getLeafConstructs(Directive D);
Directive getCompoundConstruct(ArrayRef<Directive> Parts);
+
+bool isLeafConstruct(Directive D);
+bool isCompositeConstruct(Directive D);
+bool isCombinedConstruct(Directive D);
} // namespace llvm::omp
#endif // LLVM_FRONTEND_OPENMP_OMP_H
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index dd99d3d074fd1e..98d7c63bb8537e 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -78,4 +78,29 @@ Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
return Found;
return OMPD_unknown;
}
+
+bool isLeafConstruct(Directive D) { return getLeafConstructs(D).empty(); }
+
+bool isCompositeConstruct(Directive D) {
+ // OpenMP Spec 5.2: [17.3, 8-9]
+ // If directive-name-A and directive-name-B both correspond to loop-
+ // associated constructs then directive-name is a composite construct
+ llvm::ArrayRef<Directive> Leafs{getLeafConstructs(D)};
+ if (Leafs.empty())
+ return false;
+ if (getDirectiveAssociation(Leafs.front()) != Association::Loop)
+ return false;
+
+ size_t numLoopConstructs =
+ llvm::count_if(Leafs.drop_front(), [](Directive L) {
+ return getDirectiveAssociation(L) == Association::Loop;
+ });
+ return numLoopConstructs != 0;
+}
+
+bool isCombinedConstruct(Directive D) {
+ // OpenMP Spec 5.2: [17.3, 9-10]
+ // Otherwise directive-name is a combined construct.
+ return !getLeafConstructs(D).empty() && !isCompositeConstruct(D);
+}
} // namespace llvm::omp
diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
index 29b1be4eb3432c..cc02af8bf67c26 100644
--- a/llvm/unittests/Frontend/OpenMPComposeTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
@@ -39,3 +39,29 @@ TEST(Composition, GetCompoundConstruct) {
Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
ASSERT_EQ(C7, OMPD_parallel_for_simd);
}
+
+TEST(Composition, IsLeafConstruct) {
+ ASSERT_TRUE(isLeafConstruct(OMPD_loop));
+ ASSERT_TRUE(isLeafConstruct(OMPD_teams));
+ ASSERT_FALSE(isLeafConstruct(OMPD_for_simd));
+ ASSERT_FALSE(isLeafConstruct(OMPD_distribute_simd));
+}
+
+TEST(Composition, IsCompositeConstruct) {
+ ASSERT_TRUE(isCompositeConstruct(OMPD_distribute_simd));
+ ASSERT_FALSE(isCompositeConstruct(OMPD_for));
+ ASSERT_TRUE(isCompositeConstruct(OMPD_for_simd));
+ // directive-name-A = "parallel", directive-name-B = "for simd",
+ // only directive-name-B is loop-associated, so this is not a
+ // composite construct, even though "for simd" is.
+ ASSERT_FALSE(isCompositeConstruct(OMPD_parallel_for_simd));
+}
+
+TEST(Composition, IsCombinedConstruct) {
+ // "parallel for simd" is a combined construct, see comment in
+ // IsCompositeConstruct.
+ ASSERT_TRUE(isCombinedConstruct(OMPD_parallel_for_simd));
+ ASSERT_FALSE(isCombinedConstruct(OMPD_for_simd));
+ ASSERT_TRUE(isCombinedConstruct(OMPD_parallel_for));
+ ASSERT_FALSE(isCombinedConstruct(OMPD_parallel));
+}
>From 0d92781c7a52ed2fbab33ae6e7b3dae61cfd42ae Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 2 Apr 2024 08:20:15 -0500
Subject: [PATCH 08/17] Address review comments
---
llvm/lib/Frontend/OpenMP/OMP.cpp | 10 ++++++++--
llvm/unittests/Frontend/OpenMPComposeTest.cpp | 10 ++++------
2 files changed, 12 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index dd99d3d074fd1e..7504c9076fde1b 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -27,8 +27,8 @@ using namespace llvm::omp;
namespace llvm::omp {
ArrayRef<Directive> getLeafConstructs(Directive D) {
- auto Idx = static_cast<int>(D);
- if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize))
+ auto Idx = static_cast<std::size_t>(D);
+ if (Idx >= Directive_enumSize)
return {};
const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
@@ -50,6 +50,12 @@ Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
RawLeafs.push_back(P);
}
+ // RawLeafs will be used as key in the binary search. The search doesn't
+ // guarantee that the exact same entry will be found (since RawLeafs may
+ // not correspond to any compound directive). Because of that, we will
+ // need to compare the search result with the given set of leafs.
+ // Also, if there is only one leaf in the list, it corresponds to itself,
+ // no search is necessary.
auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
if (GivenLeafs.size() == 1)
return GivenLeafs.front();
diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
index 29b1be4eb3432c..c3e0880ece8641 100644
--- a/llvm/unittests/Frontend/OpenMPComposeTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
@@ -32,10 +32,8 @@ TEST(Composition, GetCompoundConstruct) {
ASSERT_EQ(C3, OMPD_unknown);
Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
ASSERT_EQ(C4, OMPD_target_teams_distribute);
- Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
- ASSERT_EQ(C5, OMPD_target_teams_distribute);
- Directive C6 = getCompoundConstruct({});
- ASSERT_EQ(C6, OMPD_unknown);
- Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
- ASSERT_EQ(C7, OMPD_parallel_for_simd);
+ Directive C5 = getCompoundConstruct({});
+ ASSERT_EQ(C5, OMPD_unknown);
+ Directive C6 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
+ ASSERT_EQ(C6, OMPD_parallel_for_simd);
}
>From 46770f8dfe25528e970e5908aae8b2a788655bfc Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Fri, 29 Mar 2024 09:20:41 -0500
Subject: [PATCH 09/17] [flang][OpenMP] Move clause/object conversion to happen
early, in genOMP
This removes the last use of genObjectList2, which has now been removed.
---
flang/lib/Lower/OpenMP/ClauseProcessor.h | 5 +-
flang/lib/Lower/OpenMP/DataSharingProcessor.h | 5 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 424 +++++++++---------
flang/lib/Lower/OpenMP/Utils.cpp | 30 +-
flang/lib/Lower/OpenMP/Utils.h | 6 +-
5 files changed, 218 insertions(+), 252 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index db7a1b8335f818..f4d659b70cfee7 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -49,9 +49,8 @@ class ClauseProcessor {
public:
ClauseProcessor(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses)
- : converter(converter), semaCtx(semaCtx),
- clauses(makeClauses(clauses, semaCtx)) {}
+ const List<Clause> &clauses)
+ : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
// 'Unique' clauses: They can appear at most once in the clause list.
bool processCollapse(
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
index c11ee299c5d085..ef7b14327278e3 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -78,13 +78,12 @@ class DataSharingProcessor {
public:
DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &opClauseList,
+ const List<Clause> &clauses,
Fortran::lower::pft::Evaluation &eval,
bool useDelayedPrivatization = false,
Fortran::lower::SymMap *symTable = nullptr)
: hasLastPrivateOp(false), converter(converter),
- firOpBuilder(converter.getFirOpBuilder()),
- clauses(omp::makeClauses(opClauseList, semaCtx)), eval(eval),
+ firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval),
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable) {}
// Privatisation is split into two steps.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index edae453972d3d9..23dc25ac1ae9a1 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -17,6 +17,7 @@
#include "DataSharingProcessor.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
+#include "Utils.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
@@ -310,14 +311,15 @@ static void getDeclareTargetInfo(
} else if (const auto *clauseList{
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
spec.u)}) {
- if (clauseList->v.empty()) {
+ List<Clause> clauses = makeClauses(*clauseList, semaCtx);
+ if (clauses.empty()) {
// Case: declare target, implicit capture of function
symbolAndClause.emplace_back(
mlir::omp::DeclareTargetCaptureClause::to,
eval.getOwningProcedure()->getSubprogramSymbol());
}
- ClauseProcessor cp(converter, semaCtx, *clauseList);
+ ClauseProcessor cp(converter, semaCtx, clauses);
cp.processDeviceType(clauseOps);
cp.processEnter(symbolAndClause);
cp.processLink(symbolAndClause);
@@ -597,14 +599,11 @@ static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) {
// TODO: Generate the reduction operation during lowering instead of creating
// and removing operations since this is not a robust approach. Also, removing
// ops in the builder (instead of a rewriter) is probably not the best approach.
-static void
-genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauseList) {
+static void genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- List<Clause> clauses{makeClauses(clauseList, semaCtx)};
-
for (const Clause &clause : clauses) {
if (const auto &reductionClause =
std::get_if<clause::Reduction>(&clause.u)) {
@@ -812,7 +811,7 @@ struct OpWithBodyGenInfo {
return *this;
}
- OpWithBodyGenInfo &setClauses(const Fortran::parser::OmpClauseList *value) {
+ OpWithBodyGenInfo &setClauses(const List<Clause> *value) {
clauses = value;
return *this;
}
@@ -848,7 +847,7 @@ struct OpWithBodyGenInfo {
/// [in] is this an outer operation - prevents privatization.
bool outerCombined = false;
/// [in] list of clauses to process.
- const Fortran::parser::OmpClauseList *clauses = nullptr;
+ const List<Clause> *clauses = nullptr;
/// [in] if provided, processes the construct's data-sharing attributes.
DataSharingProcessor *dsp = nullptr;
/// [in] if provided, list of reduction symbols
@@ -1226,36 +1225,33 @@ static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) {
// Code generation functions for clauses
//===----------------------------------------------------------------------===//
-static void genCriticalDeclareClauses(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) {
+static void
+genCriticalDeclareClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ mlir::omp::CriticalClauseOps &clauseOps,
+ llvm::StringRef name) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processHint(clauseOps);
clauseOps.nameAttr =
mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name);
}
-static void genFlushClauses(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const std::optional<Fortran::parser::OmpObjectList> &objects,
- const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>
- &clauses,
- mlir::Location loc, llvm::SmallVectorImpl<mlir::Value> &operandRange) {
- if (objects)
- genObjectList2(*objects, converter, operandRange);
-
- if (clauses && clauses->size() > 0)
+static void genFlushClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const ObjectList &objects,
+ const List<Clause> &clauses, mlir::Location loc,
+ llvm::SmallVectorImpl<mlir::Value> &operandRange) {
+ genObjectList(objects, converter, operandRange);
+
+ if (clauses.size() > 0)
TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause");
}
static void
genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::OrderedRegionClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processTODO<clause::Simd>(loc, llvm::omp::Directive::OMPD_ordered);
@@ -1264,9 +1260,9 @@ genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter,
static void genParallelClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- bool processReduction, mlir::omp::ParallelClauseOps &clauseOps,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, bool processReduction,
+ mlir::omp::ParallelClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
@@ -1286,8 +1282,7 @@ static void genParallelClauses(
static void genSectionsClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
bool clausesFromBeginSections,
mlir::omp::SectionsClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
@@ -1304,9 +1299,8 @@ static void genSimdLoopClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- mlir::omp::SimdLoopClauseOps &clauseOps,
+ Fortran::lower::pft::Evaluation &eval, const List<Clause> &clauses,
+ mlir::Location loc, mlir::omp::SimdLoopClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processCollapse(loc, eval, clauseOps, iv);
@@ -1324,9 +1318,8 @@ static void genSimdLoopClauses(
static void genSingleClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &beginClauses,
- const Fortran::parser::OmpClauseList &endClauses,
- mlir::Location loc,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc,
mlir::omp::SingleClauseOps &clauseOps) {
ClauseProcessor bcp(converter, semaCtx, beginClauses);
bcp.processAllocate(clauseOps);
@@ -1340,9 +1333,8 @@ static void genSingleClauses(Fortran::lower::AbstractConverter &converter,
static void genTargetClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- bool processHostOnlyClauses, bool processReduction,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, bool processHostOnlyClauses, bool processReduction,
mlir::omp::TargetClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSyms,
llvm::SmallVectorImpl<mlir::Location> &mapSymLocs,
@@ -1368,9 +1360,8 @@ static void genTargetClauses(
static void genTargetDataClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- mlir::omp::TargetDataClauseOps &clauseOps,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, mlir::omp::TargetDataClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) {
@@ -1401,9 +1392,8 @@ static void genTargetDataClauses(
static void genTargetEnterExitUpdateDataClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- llvm::omp::Directive directive,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, llvm::omp::Directive directive,
mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processDepend(clauseOps);
@@ -1422,8 +1412,7 @@ static void genTargetEnterExitUpdateDataClauses(
static void genTaskClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
@@ -1442,8 +1431,7 @@ static void genTaskClauses(Fortran::lower::AbstractConverter &converter,
static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
@@ -1453,8 +1441,7 @@ static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter,
static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskwaitClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processTODO<clause::Depend, clause::Nowait>(
@@ -1464,8 +1451,7 @@ static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter,
static void genTeamsClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TeamsClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
@@ -1482,9 +1468,8 @@ static void genWsloopClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauses,
- const Fortran::parser::OmpClauseList *endClauses, mlir::Location loc,
+ Fortran::lower::pft::Evaluation &eval, const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc,
mlir::omp::WsloopClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
@@ -1501,8 +1486,8 @@ static void genWsloopClauses(
if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
- if (endClauses) {
- ClauseProcessor ecp(converter, semaCtx, *endClauses);
+ if (!endClauses.empty()) {
+ ClauseProcessor ecp(converter, semaCtx, endClauses);
ecp.processNowait(clauseOps);
}
@@ -1525,8 +1510,7 @@ static mlir::omp::CriticalOp
genCriticalOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
const std::optional<Fortran::parser::Name> &name) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::FlatSymbolRefAttr nameAttr;
@@ -1537,7 +1521,7 @@ genCriticalOp(Fortran::lower::AbstractConverter &converter,
auto global = mod.lookupSymbol<mlir::omp::CriticalDeclareOp>(nameStr);
if (!global) {
mlir::omp::CriticalClauseOps clauseOps;
- genCriticalDeclareClauses(converter, semaCtx, clauseList, loc, clauseOps,
+ genCriticalDeclareClauses(converter, semaCtx, clauses, loc, clauseOps,
nameStr);
mlir::OpBuilder modBuilder(mod.getBodyRegion());
@@ -1556,8 +1540,7 @@ static mlir::omp::DistributeOp
genDistributeOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
TODO(loc, "Distribute construct");
return nullptr;
}
@@ -1566,12 +1549,9 @@ static mlir::omp::FlushOp
genFlushOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const std::optional<Fortran::parser::OmpObjectList> &objectList,
- const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>
- &clauseList) {
+ const ObjectList &objects, const List<Clause> &clauses) {
llvm::SmallVector<mlir::Value> operandRange;
- genFlushClauses(converter, semaCtx, objectList, clauseList, loc,
- operandRange);
+ genFlushClauses(converter, semaCtx, objects, clauses, loc, operandRange);
return converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
converter.getCurrentLocation(), operandRange);
@@ -1591,7 +1571,7 @@ static mlir::omp::OrderedOp
genOrderedOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ const List<Clause> &clauses) {
TODO(loc, "OMPD_ordered");
return nullptr;
}
@@ -1600,10 +1580,9 @@ static mlir::omp::OrderedRegionOp
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
mlir::omp::OrderedRegionClauseOps clauseOps;
- genOrderedRegionClauses(converter, semaCtx, clauseList, loc, clauseOps);
+ genOrderedRegionClauses(converter, semaCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::OrderedRegionOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested),
@@ -1615,8 +1594,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1624,7 +1602,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
- genParallelClauses(converter, semaCtx, stmtCtx, clauseList, loc,
+ genParallelClauses(converter, semaCtx, stmtCtx, clauses, loc,
/*processReduction=*/!outerCombined, clauseOps,
reductionTypes, reductionSyms);
@@ -1637,7 +1615,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
- .setClauses(&clauseList)
+ .setClauses(&clauses)
.setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(reductionCallback);
@@ -1645,7 +1623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
bool privatize = !outerCombined;
- DataSharingProcessor dsp(converter, semaCtx, clauseList, eval,
+ DataSharingProcessor dsp(converter, semaCtx, clauses, eval,
/*useDelayedPrivatization=*/true, &symTable);
if (privatize)
@@ -1692,14 +1670,13 @@ static mlir::omp::SectionOp
genSectionOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
// Currently only private/firstprivate clause is handled, and
// all privatization is done within `omp.section` operations.
return genOpWithBody<mlir::omp::SectionOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&clauseList));
+ .setClauses(&clauses));
}
static mlir::omp::SectionsOp
@@ -1716,18 +1693,17 @@ static mlir::omp::SimdLoopOp
genSimdLoopOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
- DataSharingProcessor dsp(converter, semaCtx, clauseList, eval);
+ const List<Clause> &clauses) {
+ DataSharingProcessor dsp(converter, semaCtx, clauses, eval);
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
mlir::omp::SimdLoopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
- genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauseList, loc,
- clauseOps, iv);
+ genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauses, loc, clauseOps,
+ iv);
- auto *nestedEval =
- getCollapsedLoopEval(eval, Fortran::lower::getCollapseValue(clauseList));
+ auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(clauses));
auto ivCallback = [&](mlir::Operation *op) {
return genLoopVars(op, converter, loc, iv);
@@ -1735,7 +1711,7 @@ genSimdLoopOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::SimdLoopOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&clauseList)
+ .setClauses(&clauses)
.setDataSharingProcessor(&dsp)
.setGenRegionEntryCb(ivCallback),
clauseOps);
@@ -1745,17 +1721,16 @@ static mlir::omp::SingleOp
genSingleOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList &endClauseList) {
+ mlir::Location loc, const List<Clause> &beginClauses,
+ const List<Clause> &endClauses) {
mlir::omp::SingleClauseOps clauseOps;
- genSingleClauses(converter, semaCtx, beginClauseList, endClauseList, loc,
+ genSingleClauses(converter, semaCtx, beginClauses, endClauses, loc,
clauseOps);
return genOpWithBody<mlir::omp::SingleOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&beginClauseList),
+ .setClauses(&beginClauses),
clauseOps);
}
@@ -1763,8 +1738,7 @@ static mlir::omp::TargetOp
genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1777,7 +1751,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms;
llvm::SmallVector<mlir::Location> mapSymLocs;
llvm::SmallVector<mlir::Type> mapSymTypes;
- genTargetClauses(converter, semaCtx, stmtCtx, clauseList, loc,
+ genTargetClauses(converter, semaCtx, stmtCtx, clauses, loc,
processHostOnlyClauses, /*processReduction=*/outerCombined,
clauseOps, mapSyms, mapSymLocs, mapSymTypes);
@@ -1875,14 +1849,13 @@ static mlir::omp::TargetDataOp
genTargetDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TargetDataClauseOps clauseOps;
llvm::SmallVector<mlir::Type> useDeviceTypes;
llvm::SmallVector<mlir::Location> useDeviceLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms;
- genTargetDataClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps,
+ genTargetDataClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps,
useDeviceTypes, useDeviceLocs, useDeviceSyms);
auto targetDataOp =
@@ -1894,11 +1867,11 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
return targetDataOp;
}
-template <typename OpTy>
-static OpTy genTargetEnterExitUpdateDataOp(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+template <typename OpTy> static OpTy
+genTargetEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ mlir::Location loc,
+ const List<Clause> &clauses) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1915,8 +1888,8 @@ static OpTy genTargetEnterExitUpdateDataOp(
}
mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps;
- genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauseList,
- loc, directive, clauseOps);
+ genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauses, loc,
+ directive, clauseOps);
return firOpBuilder.create<OpTy>(loc, clauseOps);
}
@@ -1925,16 +1898,15 @@ static mlir::omp::TaskOp
genTaskOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TaskClauseOps clauseOps;
- genTaskClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps);
+ genTaskClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::TaskOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&clauseList),
+ .setClauses(&clauses),
clauseOps);
}
@@ -1942,15 +1914,14 @@ static mlir::omp::TaskgroupOp
genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
mlir::omp::TaskgroupClauseOps clauseOps;
- genTaskgroupClauses(converter, semaCtx, clauseList, loc, clauseOps);
+ genTaskgroupClauses(converter, semaCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&clauseList),
+ .setClauses(&clauses),
clauseOps);
}
@@ -1958,7 +1929,7 @@ static mlir::omp::TaskloopOp
genTaskloopOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ const List<Clause> &clauses) {
TODO(loc, "Taskloop construct");
}
@@ -1966,9 +1937,9 @@ static mlir::omp::TaskwaitOp
genTaskwaitOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ const List<Clause> &clauses) {
mlir::omp::TaskwaitClauseOps clauseOps;
- genTaskwaitClauses(converter, semaCtx, clauseList, loc, clauseOps);
+ genTaskwaitClauses(converter, semaCtx, clauses, loc, clauseOps);
return converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(loc,
clauseOps);
}
@@ -1984,17 +1955,17 @@ static mlir::omp::TeamsOp
genTeamsOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TeamsClauseOps clauseOps;
- genTeamsClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps);
+ genTeamsClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
- .setClauses(&clauseList),
+ .setClauses(&clauses),
clauseOps);
}
@@ -2002,9 +1973,8 @@ static mlir::omp::WsloopOp
genWsloopOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList) {
- DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval);
+ const List<Clause> &beginClauses, const List<Clause> &endClauses) {
+ DataSharingProcessor dsp(converter, semaCtx, beginClauses, eval);
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
@@ -2012,12 +1982,10 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Type> reductionTypes;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
- genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauseList,
- endClauseList, loc, clauseOps, iv, reductionTypes,
- reductionSyms);
+ genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauses, endClauses,
+ loc, clauseOps, iv, reductionTypes, reductionSyms);
- auto *nestedEval = getCollapsedLoopEval(
- eval, Fortran::lower::getCollapseValue(beginClauseList));
+ auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(beginClauses));
auto ivCallback = [&](mlir::Operation *op) {
return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms,
@@ -2026,7 +1994,7 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::WsloopOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&beginClauseList)
+ .setClauses(&beginClauses)
.setDataSharingProcessor(&dsp)
.setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(ivCallback),
@@ -2041,8 +2009,8 @@ static void genCompositeDistributeParallelDo(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
}
@@ -2050,8 +2018,8 @@ static void genCompositeDistributeParallelDoSimd(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
}
@@ -2059,8 +2027,8 @@ static void genCompositeDistributeSimd(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE SIMD");
}
@@ -2068,10 +2036,10 @@ static void
genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses,
mlir::Location loc) {
- ClauseProcessor cp(converter, semaCtx, beginClauseList);
+ ClauseProcessor cp(converter, semaCtx, beginClauses);
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
clause::Order, clause::Safelen, clause::Simdlen>(
loc, llvm::omp::OMPD_do_simd);
@@ -2083,15 +2051,15 @@ genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
// When support for vectorization is enabled, then we need to add handling of
// if clause. Currently if clause can be skipped because we always assume
// SIMD length = 1.
- genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList);
+ genWsloopOp(converter, semaCtx, eval, loc, beginClauses, endClauses);
}
static void
genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses,
mlir::Location loc) {
TODO(loc, "Composite TASKLOOP SIMD");
}
@@ -2201,8 +2169,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
const auto &directive =
std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
simpleStandaloneConstruct.t);
- const auto &clauseList =
- std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t);
+ List<Clause> clauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t),
+ semaCtx);
mlir::Location currentLocation = converter.genLocation(directive.source);
switch (directive.v) {
@@ -2212,29 +2181,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
genBarrierOp(converter, semaCtx, eval, currentLocation);
break;
case llvm::omp::Directive::OMPD_taskwait:
- genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauseList);
+ genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_taskyield:
genTaskyieldOp(converter, semaCtx, eval, currentLocation);
break;
case llvm::omp::Directive::OMPD_target_data:
genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true,
- currentLocation, clauseList);
+ currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_target_enter_data:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetEnterDataOp>(
- converter, semaCtx, currentLocation, clauseList);
+ converter, semaCtx, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_target_exit_data:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetExitDataOp>(
- converter, semaCtx, currentLocation, clauseList);
+ converter, semaCtx, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_target_update:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetUpdateOp>(
- converter, semaCtx, currentLocation, clauseList);
+ converter, semaCtx, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_ordered:
- genOrderedOp(converter, semaCtx, eval, currentLocation, clauseList);
+ genOrderedOp(converter, semaCtx, eval, currentLocation, clauses);
break;
}
}
@@ -2251,8 +2220,14 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const auto &clauseList =
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
flushConstruct.t);
+ ObjectList objects =
+ objectList ? makeObjects(*objectList, semaCtx) : ObjectList{};
+ List<Clause> clauses =
+ clauseList ? makeList(*clauseList,
+ [&](auto &&s) { return makeClause(s.v, semaCtx); })
+ : List<Clause>{};
mlir::Location currentLocation = converter.genLocation(verbatim.source);
- genFlushOp(converter, semaCtx, eval, currentLocation, objectList, clauseList);
+ genFlushOp(converter, semaCtx, eval, currentLocation, objects, clauses);
}
static void
@@ -2357,44 +2332,44 @@ genOMP(Fortran::lower::AbstractConverter &converter,
converter.genLocation(beginBlockDirective.source);
const auto origDirective =
std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t).v;
- const auto &beginClauseList =
- std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
- const auto &endClauseList =
- std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t);
+ List<Clause> beginClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t), semaCtx);
+ List<Clause> endClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t), semaCtx);
assert(llvm::omp::blockConstructSet.test(origDirective) &&
"Expected block construct");
- for (const Fortran::parser::OmpClause &clause : beginClauseList.v) {
+ for (const Clause &clause : beginClauses) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
- if (!std::get_if<Fortran::parser::OmpClause::If>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::NumThreads>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Allocate>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Default>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Final>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Priority>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Firstprivate>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Simd>(&clause.u)) {
+ if (!std::get_if<clause::If>(&clause.u) &&
+ !std::get_if<clause::NumThreads>(&clause.u) &&
+ !std::get_if<clause::ProcBind>(&clause.u) &&
+ !std::get_if<clause::Allocate>(&clause.u) &&
+ !std::get_if<clause::Default>(&clause.u) &&
+ !std::get_if<clause::Final>(&clause.u) &&
+ !std::get_if<clause::Priority>(&clause.u) &&
+ !std::get_if<clause::Reduction>(&clause.u) &&
+ !std::get_if<clause::Depend>(&clause.u) &&
+ !std::get_if<clause::Private>(&clause.u) &&
+ !std::get_if<clause::Firstprivate>(&clause.u) &&
+ !std::get_if<clause::Copyin>(&clause.u) &&
+ !std::get_if<clause::Shared>(&clause.u) &&
+ !std::get_if<clause::Threads>(&clause.u) &&
+ !std::get_if<clause::Map>(&clause.u) &&
+ !std::get_if<clause::UseDevicePtr>(&clause.u) &&
+ !std::get_if<clause::UseDeviceAddr>(&clause.u) &&
+ !std::get_if<clause::ThreadLimit>(&clause.u) &&
+ !std::get_if<clause::NumTeams>(&clause.u) &&
+ !std::get_if<clause::Simd>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
}
}
- for (const auto &clause : endClauseList.v) {
+ for (const Clause &clause : endClauses) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
- if (!std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Copyprivate>(&clause.u))
+ if (!std::get_if<clause::Nowait>(&clause.u) &&
+ !std::get_if<clause::Copyprivate>(&clause.u))
TODO(clauseLocation, "OpenMP Block construct clause");
}
@@ -2413,44 +2388,44 @@ genOMP(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_ordered:
// 2.17.9 ORDERED construct.
genOrderedRegionOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_parallel:
// 2.6 PARALLEL construct.
genParallelOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, beginClauseList, outerCombined);
+ currentLocation, beginClauses, outerCombined);
break;
case llvm::omp::Directive::OMPD_single:
// 2.8.2 SINGLE construct.
genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, endClauseList);
+ beginClauses, endClauses);
break;
case llvm::omp::Directive::OMPD_target:
// 2.12.5 TARGET construct.
genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, outerCombined);
+ beginClauses, outerCombined);
break;
case llvm::omp::Directive::OMPD_target_data:
// 2.12.2 TARGET DATA construct.
genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_task:
// 2.10.1 TASK construct.
genTaskOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_taskgroup:
// 2.17.6 TASKGROUP construct.
genTaskgroupOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_teams:
// 2.7 TEAMS construct.
// FIXME Pass the outerCombined argument or rename it to better describe
// what it represents if it must always be `false` in this context.
genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_workshare:
// 2.8.3 WORKSHARE construct.
@@ -2458,7 +2433,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
// implementation for this feature will come later. For the codes
// that use this construct, add a single construct for now.
genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, endClauseList);
+ beginClauses, endClauses);
break;
default:
llvm_unreachable("Unexpected block construct");
@@ -2476,11 +2451,12 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
const auto &cd =
std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
- const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
+ List<Clause> clauses =
+ makeClauses(std::get<Fortran::parser::OmpClauseList>(cd.t), semaCtx);
const auto &name = std::get<std::optional<Fortran::parser::Name>>(cd.t);
mlir::Location currentLocation = converter.getCurrentLocation();
genCriticalOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
- clauseList, name);
+ clauses, name);
}
static void
@@ -2499,8 +2475,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
const auto &beginLoopDirective =
std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t);
- const auto &beginClauseList =
- std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t);
+ List<Clause> beginClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);
const auto origDirective =
@@ -2509,15 +2485,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
assert(llvm::omp::loopConstructSet.test(origDirective) &&
"Expected loop construct");
- const auto *endClauseList = [&]() {
- using RetTy = const Fortran::parser::OmpClauseList *;
+ List<Clause> endClauses = [&]() {
if (auto &endLoopDirective =
std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>(
loopConstruct.t)) {
- return RetTy(
- &std::get<Fortran::parser::OmpClauseList>((*endLoopDirective).t));
+ return makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(endLoopDirective->t),
+ semaCtx);
}
- return RetTy();
+ return List<Clause>{};
}();
std::optional<llvm::omp::Directive> nextDir = origDirective;
@@ -2530,29 +2506,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_distribute_parallel_do:
// 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct.
genCompositeDistributeParallelDo(converter, semaCtx, eval,
- beginClauseList, endClauseList,
+ beginClauses, endClauses,
currentLocation);
break;
case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
// 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct.
genCompositeDistributeParallelDoSimd(converter, semaCtx, eval,
- beginClauseList, endClauseList,
+ beginClauses, endClauses,
currentLocation);
break;
case llvm::omp::Directive::OMPD_distribute_simd:
// 2.9.4.2 DISTRIBUTE SIMD construct.
- genCompositeDistributeSimd(converter, semaCtx, eval, beginClauseList,
- endClauseList, currentLocation);
+ genCompositeDistributeSimd(converter, semaCtx, eval, beginClauses,
+ endClauses, currentLocation);
break;
case llvm::omp::Directive::OMPD_do_simd:
// 2.9.3.2 Worksharing-Loop SIMD construct.
- genCompositeDoSimd(converter, semaCtx, eval, beginClauseList,
- endClauseList, currentLocation);
+ genCompositeDoSimd(converter, semaCtx, eval, beginClauses,
+ endClauses, currentLocation);
break;
case llvm::omp::Directive::OMPD_taskloop_simd:
// 2.10.3 TASKLOOP SIMD construct.
- genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauseList,
- endClauseList, currentLocation);
+ genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauses,
+ endClauses, currentLocation);
break;
default:
llvm_unreachable("Unexpected composite construct");
@@ -2563,12 +2539,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_distribute:
// 2.9.4.1 DISTRIBUTE construct.
genDistributeOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_do:
// 2.9.2 Worksharing-Loop construct.
- genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList,
- endClauseList);
+ genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauses,
+ endClauses);
break;
case llvm::omp::Directive::OMPD_parallel:
// 2.6 PARALLEL construct.
@@ -2577,24 +2553,24 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
// Maybe rename the argument if it represents something else or
// initialize it properly.
genParallelOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, beginClauseList,
+ currentLocation, beginClauses,
/*outerCombined=*/true);
break;
case llvm::omp::Directive::OMPD_simd:
// 2.9.3.1 SIMD construct.
genSimdLoopOp(converter, semaCtx, eval, currentLocation,
- beginClauseList);
- genOpenMPReduction(converter, semaCtx, beginClauseList);
+ beginClauses);
+ genOpenMPReduction(converter, semaCtx, beginClauses);
break;
case llvm::omp::Directive::OMPD_target:
// 2.12.5 TARGET construct.
genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, /*outerCombined=*/true);
+ beginClauses, /*outerCombined=*/true);
break;
case llvm::omp::Directive::OMPD_taskloop:
// 2.10.2 TASKLOOP construct.
genTaskloopOp(converter, semaCtx, eval, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_teams:
// 2.7 TEAMS construct.
@@ -2603,7 +2579,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
// Maybe rename the argument if it represents something else or
// initialize it properly.
genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, /*outerCombined=*/true);
+ beginClauses, /*outerCombined=*/true);
break;
case llvm::omp::Directive::OMPD_loop:
case llvm::omp::Directive::OMPD_masked:
@@ -2639,14 +2615,15 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
const auto &beginSectionsDirective =
std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t);
- const auto &beginClauseList =
- std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t);
+ List<Clause> beginClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t),
+ semaCtx);
// Process clauses before optional omp.parallel, so that new variables are
// allocated outside of the parallel region
mlir::Location currentLocation = converter.getCurrentLocation();
mlir::omp::SectionsClauseOps clauseOps;
- genSectionsClauses(converter, semaCtx, beginClauseList, currentLocation,
+ genSectionsClauses(converter, semaCtx, beginClauses, currentLocation,
/*clausesFromBeginSections=*/true, clauseOps);
// Parallel wrapper of PARALLEL SECTIONS construct
@@ -2655,14 +2632,15 @@ genOMP(Fortran::lower::AbstractConverter &converter,
.v;
if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
genParallelOp(converter, symTable, semaCtx, eval,
- /*genNested=*/false, currentLocation, beginClauseList,
+ /*genNested=*/false, currentLocation, beginClauses,
/*outerCombined=*/true);
} else {
const auto &endSectionsDirective =
std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
- const auto &endClauseList =
- std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
- genSectionsClauses(converter, semaCtx, endClauseList, currentLocation,
+ List<Clause> endClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t),
+ semaCtx);
+ genSectionsClauses(converter, semaCtx, endClauses, currentLocation,
/*clausesFromBeginSections=*/false, clauseOps);
}
@@ -2678,7 +2656,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) {
symTable.pushScope();
genSectionOp(converter, semaCtx, neval, /*genNested=*/true, currentLocation,
- beginClauseList);
+ beginClauses);
symTable.popScope();
firOpBuilder.restoreInsertionPoint(ip);
}
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index b9c0660aa4da8e..da3f2be73e5095 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -36,6 +36,17 @@ namespace Fortran {
namespace lower {
namespace omp {
+int64_t getCollapseValue(const List<Clause> &clauses) {
+ auto iter = llvm::find_if(clauses, [](const Clause &clause) {
+ return clause.id == llvm::omp::Clause::OMPC_collapse;
+ });
+ if (iter != clauses.end()) {
+ const auto &collapse = std::get<clause::Collapse>(iter->u);
+ return evaluate::ToInt64(collapse.v).value();
+ }
+ return 1;
+}
+
void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands) {
@@ -52,25 +63,6 @@ void genObjectList(const ObjectList &objects,
}
}
-void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- llvm::SmallVectorImpl<mlir::Value> &operands) {
- auto addOperands = [&](Fortran::lower::SymbolRef sym) {
- const mlir::Value variable = converter.getSymbolAddress(sym);
- if (variable) {
- operands.push_back(variable);
- } else if (const auto *details =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
- operands.push_back(converter.getSymbolAddress(details->symbol()));
- converter.copySymbolBinding(details->symbol(), sym);
- }
- };
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
- addOperands(*sym);
- }
-}
-
mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
std::size_t loopVarTypeSize) {
// OpenMP runtime requires 32-bit or 64-bit loop variables.
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 4074bf73987d5b..b3a9f7f30c98bd 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -58,6 +58,8 @@ void gatherFuncAndVarSyms(
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);
+int64_t getCollapseValue(const List<Clause> &clauses);
+
Fortran::semantics::Symbol *
getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject);
@@ -65,10 +67,6 @@ void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands);
-void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- llvm::SmallVectorImpl<mlir::Value> &operands);
-
} // namespace omp
} // namespace lower
} // namespace Fortran
>From 065b54c4ddf2b356333269aecbee00b5a23ca1ea Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 17 Apr 2024 09:37:17 -0500
Subject: [PATCH 10/17] clang-format
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 92 +++++++++++++++----------------
1 file changed, 43 insertions(+), 49 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 6a7699aee5931e..4424788e0132e2 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1022,22 +1022,23 @@ static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) {
// Code generation functions for clauses
//===----------------------------------------------------------------------===//
-static void genCriticalDeclareClauses(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const List<Clause> &clauses, mlir::Location loc,
- mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) {
+static void
+genCriticalDeclareClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ mlir::omp::CriticalClauseOps &clauseOps,
+ llvm::StringRef name) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processHint(clauseOps);
clauseOps.nameAttr =
mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name);
}
-static void genFlushClauses(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const ObjectList &objects, const List<Clause> &clauses,
- mlir::Location loc, llvm::SmallVectorImpl<mlir::Value> &operandRange) {
+static void genFlushClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const ObjectList &objects,
+ const List<Clause> &clauses, mlir::Location loc,
+ llvm::SmallVectorImpl<mlir::Value> &operandRange) {
if (!objects.empty())
genObjectList(objects, converter, operandRange);
@@ -1048,9 +1049,8 @@ static void genFlushClauses(
static void genLoopNestClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const List<Clause> &clauses, mlir::Location loc,
- mlir::omp::LoopNestClauseOps &clauseOps,
+ Fortran::lower::pft::Evaluation &eval, const List<Clause> &clauses,
+ mlir::Location loc, mlir::omp::LoopNestClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processCollapse(loc, eval, clauseOps, iv);
@@ -1069,9 +1069,9 @@ genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter,
static void genParallelClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const List<Clause> &clauses, mlir::Location loc,
- bool processReduction, mlir::omp::ParallelClauseOps &clauseOps,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, bool processReduction,
+ mlir::omp::ParallelClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
@@ -1136,9 +1136,8 @@ static void genSingleClauses(Fortran::lower::AbstractConverter &converter,
static void genTargetClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const List<Clause> &clauses, mlir::Location loc,
- bool processHostOnlyClauses, bool processReduction,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, bool processHostOnlyClauses, bool processReduction,
mlir::omp::TargetClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSyms,
llvm::SmallVectorImpl<mlir::Location> &mapLocs,
@@ -1708,10 +1707,11 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
}
template <typename OpTy>
-static OpTy genTargetEnterExitUpdateDataOp(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc,
- const List<Clause> &clauses) {
+static OpTy
+genTargetEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ mlir::Location loc,
+ const List<Clause> &clauses) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1852,8 +1852,7 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
static void genCompositeDistributeParallelDo(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const List<Clause> &beginClauses,
+ Fortran::lower::pft::Evaluation &eval, const List<Clause> &beginClauses,
const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
}
@@ -1861,28 +1860,26 @@ static void genCompositeDistributeParallelDo(
static void genCompositeDistributeParallelDoSimd(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const List<Clause> &beginClauses,
+ Fortran::lower::pft::Evaluation &eval, const List<Clause> &beginClauses,
const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
}
-static void genCompositeDistributeSimd(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const List<Clause> &beginClauses,
- const List<Clause> &endClauses, mlir::Location loc) {
+static void
+genCompositeDistributeSimd(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE SIMD");
}
-static void
-genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const List<Clause> &beginClauses,
- const List<Clause> &endClauses,
- mlir::Location loc) {
+static void genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses,
+ mlir::Location loc) {
ClauseProcessor cp(converter, semaCtx, beginClauses);
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
clause::Order, clause::Safelen, clause::Simdlen>(
@@ -1903,8 +1900,7 @@ genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
const List<Clause> &beginClauses,
- const List<Clause> &endClauses,
- mlir::Location loc) {
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite TASKLOOP SIMD");
}
@@ -2351,9 +2347,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
switch (leafDir) {
case llvm::omp::Directive::OMPD_distribute_parallel_do:
// 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct.
- genCompositeDistributeParallelDo(converter, semaCtx, eval,
- beginClauses, endClauses,
- currentLocation);
+ genCompositeDistributeParallelDo(converter, semaCtx, eval, beginClauses,
+ endClauses, currentLocation);
break;
case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
// 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct.
@@ -2368,8 +2363,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
break;
case llvm::omp::Directive::OMPD_do_simd:
// 2.9.3.2 Worksharing-Loop SIMD construct.
- genCompositeDoSimd(converter, semaCtx, eval, beginClauses,
- endClauses, currentLocation);
+ genCompositeDoSimd(converter, semaCtx, eval, beginClauses, endClauses,
+ currentLocation);
break;
case llvm::omp::Directive::OMPD_taskloop_simd:
// 2.10.3 TASKLOOP SIMD construct.
@@ -2413,8 +2408,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
break;
case llvm::omp::Directive::OMPD_taskloop:
// 2.10.2 TASKLOOP construct.
- genTaskloopOp(converter, semaCtx, eval, currentLocation,
- beginClauses);
+ genTaskloopOp(converter, semaCtx, eval, currentLocation, beginClauses);
break;
case llvm::omp::Directive::OMPD_teams:
// 2.7 TEAMS construct.
>From cb7c0f8c1d929939bccbd2565cd11132c18a9687 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 17 Apr 2024 11:42:32 -0500
Subject: [PATCH 11/17] Rename test
---
.../Frontend/{OpenMPComposeTest.cpp => OpenMPCompositionTest.cpp} | 0
1 file changed, 0 insertions(+), 0 deletions(-)
rename llvm/unittests/Frontend/{OpenMPComposeTest.cpp => OpenMPCompositionTest.cpp} (100%)
diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPCompositionTest.cpp
similarity index 100%
rename from llvm/unittests/Frontend/OpenMPComposeTest.cpp
rename to llvm/unittests/Frontend/OpenMPCompositionTest.cpp
>From 8f935fbcd7ee8a572e0a5242d1b6ad5c5a70975d Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 17 Apr 2024 11:46:48 -0500
Subject: [PATCH 12/17] Finish the renaming
---
llvm/unittests/Frontend/CMakeLists.txt | 2 +-
llvm/unittests/Frontend/OpenMPCompositionTest.cpp | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt
index ddb6a16cbb984e..3f290b63ba6479 100644
--- a/llvm/unittests/Frontend/CMakeLists.txt
+++ b/llvm/unittests/Frontend/CMakeLists.txt
@@ -14,7 +14,7 @@ add_llvm_unittest(LLVMFrontendTests
OpenMPContextTest.cpp
OpenMPIRBuilderTest.cpp
OpenMPParsingTest.cpp
- OpenMPComposeTest.cpp
+ OpenMPCompositionTest.cpp
DEPENDS
acc_gen
diff --git a/llvm/unittests/Frontend/OpenMPCompositionTest.cpp b/llvm/unittests/Frontend/OpenMPCompositionTest.cpp
index 5e9b2c2df174a3..0b32e0d96dc84c 100644
--- a/llvm/unittests/Frontend/OpenMPCompositionTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPCompositionTest.cpp
@@ -1,4 +1,4 @@
-//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===//
+//===- llvm/unittests/Frontend/OpenMPCompositionTest.cpp ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
>From 4593582b2a480dfffd2dceb4611cc0dec9cd7de5 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 11 Apr 2024 10:33:44 -0500
Subject: [PATCH 13/17] [LLVM][OpenMP] Implement getLeafOrCompositeConstructs
This function breaks up a construct into its constituent leaf and
composite constructs. For example, if OMPD_c_d_e and OMPD_d_e are both
composite constructs, then OMPD_a_b_c_d_e is broken up into the list
{OMPD_a, OMPD_b, OMPD_c_d_e}: the composite part is kept whole rather than
being split further into OMPD_c and OMPD_d_e.
---
llvm/include/llvm/Frontend/OpenMP/OMP.h | 6 ++
llvm/lib/Frontend/OpenMP/OMP.cpp | 88 ++++++++++++++++---
.../Frontend/OpenMPCompositionTest.cpp | 32 +++++++
3 files changed, 113 insertions(+), 13 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h
index ec8ae68f1c2ca0..6f7a39acac1d31 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h
@@ -16,9 +16,15 @@
#include "llvm/Frontend/OpenMP/OMP.h.inc"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
namespace llvm::omp {
ArrayRef<Directive> getLeafConstructs(Directive D);
+ArrayRef<Directive> getLeafConstructsOrSelf(Directive D);
+
+ArrayRef<Directive>
+getLeafOrCompositeConstructs(Directive D, SmallVectorImpl<Directive> &Output);
+
Directive getCompoundConstruct(ArrayRef<Directive> Parts);
bool isLeafConstruct(Directive D);
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index 4b9b7037ee4ad5..2ebadf5216a084 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -25,6 +25,43 @@ using namespace llvm::omp;
#define GEN_DIRECTIVES_IMPL
#include "llvm/Frontend/OpenMP/OMP.inc"
+static iterator_range<ArrayRef<Directive>::iterator>
+getFirstCompositeRange(iterator_range<ArrayRef<Directive>::iterator> Leafs) {
+ // OpenMP Spec 5.2: [17.3, 8-9]
+ // If directive-name-A and directive-name-B both correspond to loop-
+ // associated constructs then directive-name is a composite construct
+ // otherwise directive-name is a combined construct.
+ //
+ // In the list of leaf constructs, find the first loop-associated construct,
+ // this is the beginning of the range. Then, starting from the immediately
+ // following leaf construct, find the first sequence of adjacent loop-
+ // -associated constructs. The last of those is the last one of the range.
+
+ auto firstLoopAssociated =
+ [](iterator_range<ArrayRef<Directive>::iterator> List) {
+ for (auto It = List.begin(), End = List.end(); It != End; ++It) {
+ if (getDirectiveAssociation(*It) == Association::Loop)
+ return It;
+ }
+ return List.end();
+ };
+
+ auto Begin = firstLoopAssociated(Leafs);
+ if (Begin == Leafs.end())
+ return llvm::make_range(Leafs.end(), Leafs.end());
+
+ auto End =
+ firstLoopAssociated(llvm::make_range(std::next(Begin), Leafs.end()));
+ if (End == Leafs.end())
+ return llvm::make_range(Begin, std::next(Begin));
+
+ for (; End != Leafs.end(); ++End) {
+ if (getDirectiveAssociation(*End) != Association::Loop)
+ break;
+ }
+ return llvm::make_range(Begin, End);
+}
+
namespace llvm::omp {
ArrayRef<Directive> getLeafConstructs(Directive D) {
auto Idx = static_cast<std::size_t>(D);
@@ -34,6 +71,40 @@ ArrayRef<Directive> getLeafConstructs(Directive D) {
return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
}
+ArrayRef<Directive> getLeafConstructsOrSelf(Directive D) {
+ if (auto Leafs = getLeafConstructs(D); !Leafs.empty())
+ return Leafs;
+ auto Idx = static_cast<size_t>(D);
+ assert(Idx < Directive_enumSize && "Invalid directive");
+ const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
+ // The first entry in the row is the directive itself.
+ return ArrayRef(&Row[0], &Row[0] + 1);
+}
+
+ArrayRef<Directive>
+getLeafOrCompositeConstructs(Directive D, SmallVectorImpl<Directive> &Output) {
+ using ArrayTy = ArrayRef<Directive>;
+ using IteratorTy = ArrayTy::iterator;
+ ArrayRef<Directive> Leafs = getLeafConstructsOrSelf(D);
+
+ IteratorTy Iter = Leafs.begin();
+ do {
+ auto Range = getFirstCompositeRange(llvm::make_range(Iter, Leafs.end()));
+ // All directives before the range are leaf constructs.
+ for (; Iter != Range.begin(); ++Iter)
+ Output.push_back(*Iter);
+ if (!Range.empty()) {
+ Directive Comp =
+ getCompoundConstruct(ArrayTy(Range.begin(), Range.end()));
+ assert(Comp != OMPD_unknown);
+ Output.push_back(Comp);
+ }
+ Iter = Range.end();
+ } while (Iter != Leafs.end());
+
+ return Output;
+}
+
Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
if (Parts.empty())
return OMPD_unknown;
@@ -88,20 +159,11 @@ Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
bool isLeafConstruct(Directive D) { return getLeafConstructs(D).empty(); }
bool isCompositeConstruct(Directive D) {
- // OpenMP Spec 5.2: [17.3, 8-9]
- // If directive-name-A and directive-name-B both correspond to loop-
- // associated constructs then directive-name is a composite construct
- llvm::ArrayRef<Directive> Leafs{getLeafConstructs(D)};
- if (Leafs.empty())
- return false;
- if (getDirectiveAssociation(Leafs.front()) != Association::Loop)
+ ArrayRef<Directive> Leafs = getLeafConstructsOrSelf(D);
+ if (Leafs.size() <= 1)
return false;
-
- size_t numLoopConstructs =
- llvm::count_if(Leafs.drop_front(), [](Directive L) {
- return getDirectiveAssociation(L) == Association::Loop;
- });
- return numLoopConstructs != 0;
+ auto Range = getFirstCompositeRange(Leafs);
+ return Range.begin() == Leafs.begin() && Range.end() == Leafs.end();
}
bool isCombinedConstruct(Directive D) {
diff --git a/llvm/unittests/Frontend/OpenMPCompositionTest.cpp b/llvm/unittests/Frontend/OpenMPCompositionTest.cpp
index 0b32e0d96dc84c..6915a0cbcaac2d 100644
--- a/llvm/unittests/Frontend/OpenMPCompositionTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPCompositionTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "gtest/gtest.h"
@@ -38,6 +39,37 @@ TEST(Composition, GetCompoundConstruct) {
ASSERT_EQ(C6, OMPD_parallel_for_simd);
}
+TEST(Composition, GetLeafOrCompositeConstructs) {
+ SmallVector<Directive> Out1;
+ auto Ret1 = getLeafOrCompositeConstructs(
+ OMPD_target_teams_distribute_parallel_for, Out1);
+ ASSERT_EQ(Ret1, ArrayRef<Directive>(Out1));
+ ASSERT_EQ((ArrayRef<Directive>(Out1)),
+ (ArrayRef<Directive>{OMPD_target, OMPD_teams,
+ OMPD_distribute_parallel_for}));
+
+ SmallVector<Directive> Out2;
+ auto Ret2 =
+ getLeafOrCompositeConstructs(OMPD_parallel_masked_taskloop_simd, Out2);
+ ASSERT_EQ(Ret2, ArrayRef<Directive>(Out2));
+ ASSERT_EQ(
+ (ArrayRef<Directive>(Out2)),
+ (ArrayRef<Directive>{OMPD_parallel, OMPD_masked, OMPD_taskloop_simd}));
+
+ SmallVector<Directive> Out3;
+ auto Ret3 =
+ getLeafOrCompositeConstructs(OMPD_distribute_parallel_do_simd, Out3);
+ ASSERT_EQ(Ret3, ArrayRef<Directive>(Out3));
+ ASSERT_EQ((ArrayRef<Directive>(Out3)),
+ (ArrayRef<Directive>{OMPD_distribute_parallel_do_simd}));
+
+ SmallVector<Directive> Out4;
+ auto Ret4 = getLeafOrCompositeConstructs(OMPD_target_parallel_loop, Out4);
+ ASSERT_EQ(Ret4, ArrayRef<Directive>(Out4));
+ ASSERT_EQ((ArrayRef<Directive>(Out4)),
+ (ArrayRef<Directive>{OMPD_target, OMPD_parallel, OMPD_loop}));
+}
+
TEST(Composition, IsLeafConstruct) {
ASSERT_TRUE(isLeafConstruct(OMPD_loop));
ASSERT_TRUE(isLeafConstruct(OMPD_teams));
>From a626be8e9b0561c58c26ed77f92966109bbd7041 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Mon, 22 Apr 2024 09:18:15 -0500
Subject: [PATCH 14/17] Address review comments
---
llvm/lib/Frontend/OpenMP/OMP.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index 2ebadf5216a084..59f80838a9989e 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -66,9 +66,9 @@ namespace llvm::omp {
ArrayRef<Directive> getLeafConstructs(Directive D) {
auto Idx = static_cast<std::size_t>(D);
if (Idx >= Directive_enumSize)
- return {};
+ return std::nullopt;
const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
- return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
+ return ArrayRef(&Row[2], static_cast<int>(Row[1]));
}
ArrayRef<Directive> getLeafConstructsOrSelf(Directive D) {
>From a37ae8670ca04293b65b4787664e2b85d46030e9 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Mon, 22 Apr 2024 09:20:17 -0500
Subject: [PATCH 15/17] Revert "Address review comments"
This reverts commit a626be8e9b0561c58c26ed77f92966109bbd7041.
---
llvm/lib/Frontend/OpenMP/OMP.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index 59f80838a9989e..2ebadf5216a084 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -66,9 +66,9 @@ namespace llvm::omp {
ArrayRef<Directive> getLeafConstructs(Directive D) {
auto Idx = static_cast<std::size_t>(D);
if (Idx >= Directive_enumSize)
- return std::nullopt;
+ return {};
const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
- return ArrayRef(&Row[2], static_cast<int>(Row[1]));
+ return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
}
ArrayRef<Directive> getLeafConstructsOrSelf(Directive D) {
>From 02bbf90c2070fb7048c09d11cdfd6609d2966b66 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 23 Apr 2024 10:52:48 -0500
Subject: [PATCH 16/17] Clarify the interface of `getFirstCompositeRange`
---
llvm/lib/Frontend/OpenMP/OMP.cpp | 23 +++++++++++++++++------
1 file changed, 17 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index 85bcb862570496..ec01937575ee4a 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -33,9 +33,18 @@ getFirstCompositeRange(iterator_range<ArrayRef<Directive>::iterator> Leafs) {
// otherwise directive-name is a combined construct.
//
// In the list of leaf constructs, find the first loop-associated construct,
- // this is the beginning of the range. Then, starting from the immediately
- // following leaf construct, find the first sequence of adjacent loop-
- // -associated constructs. The last of those is the last one of the range.
+ // this is the beginning of the returned range. Then, starting from the
+ // immediately following leaf construct, find the first sequence of adjacent
+ // loop-associated constructs. The last of those is the last one of the
+ // range, that is, the end of the range is one past that element.
+ // If such a sequence of adjacent loop-associated directives does not exist,
+ // return an empty range.
+ //
+ // The end of the returned range (including empty range) is intended to be
+ // a point from which the search for the next range could resume.
+ //
+ // Consequently, this function can't return a range with a single leaf
+ // construct in it.
auto firstLoopAssociated =
[](iterator_range<ArrayRef<Directive>::iterator> List) {
@@ -46,14 +55,16 @@ getFirstCompositeRange(iterator_range<ArrayRef<Directive>::iterator> Leafs) {
return List.end();
};
+ auto Empty = llvm::make_range(Leafs.end(), Leafs.end());
+
auto Begin = firstLoopAssociated(Leafs);
if (Begin == Leafs.end())
- return llvm::make_range(Leafs.end(), Leafs.end());
+ return Empty;
auto End =
firstLoopAssociated(llvm::make_range(std::next(Begin), Leafs.end()));
if (End == Leafs.end())
- return llvm::make_range(Begin, std::next(Begin));
+ return Empty;
for (; End != Leafs.end(); ++End) {
if (getDirectiveAssociation(*End) != Association::Loop)
@@ -98,8 +109,8 @@ getLeafOrCompositeConstructs(Directive D, SmallVectorImpl<Directive> &Output) {
getCompoundConstruct(ArrayTy(Range.begin(), Range.end()));
assert(Comp != OMPD_unknown);
Output.push_back(Comp);
+ Iter = Range.end();
}
- Iter = Range.end();
} while (Iter != Leafs.end());
return Output;
>From cc539dbe1119d92ed233beecb6571c2b7ea01dd6 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 23 Apr 2024 11:02:11 -0500
Subject: [PATCH 17/17] Add assertion to verify that composite construct
consumed all remaining leafs
---
llvm/lib/Frontend/OpenMP/OMP.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index ec01937575ee4a..c1556ff3c74d72 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -110,6 +110,10 @@ getLeafOrCompositeConstructs(Directive D, SmallVectorImpl<Directive> &Output) {
assert(Comp != OMPD_unknown);
Output.push_back(Comp);
Iter = Range.end();
+ // As of now, a composite construct must extend to the end of the list
+ // of leaf constructs: once a composite range starts, it has to cover
+ // every remaining leaf construct of the directive.
+ assert(Iter == Leafs.end() && "Malformed directive");
}
} while (Iter != Leafs.end());
More information about the flang-commits
mailing list