[flang-commits] [flang] ca1bd59 - [flang][OpenMP] Decompose compound constructs, do recursive lowering (#90098)
via flang-commits
flang-commits at lists.llvm.org
Mon May 13 06:09:29 PDT 2024
Author: Krzysztof Parzyszek
Date: 2024-05-13T08:09:24-05:00
New Revision: ca1bd5995f6ed934f9187305190a5abfac049173
URL: https://github.com/llvm/llvm-project/commit/ca1bd5995f6ed934f9187305190a5abfac049173
DIFF: https://github.com/llvm/llvm-project/commit/ca1bd5995f6ed934f9187305190a5abfac049173.diff
LOG: [flang][OpenMP] Decompose compound constructs, do recursive lowering (#90098)
A compound construct with a list of clauses is broken up into individual
leaf/composite constructs. Each such construct has the list of clauses
that apply to it based on the OpenMP spec.
Each lowering function (i.e. a function that generates MLIR ops) is now
responsible for generating its body as described below.
Functions that receive AST nodes extract the construct and the clauses
from the node. They then create a work queue consisting of individual
constructs and invoke a common dispatch function to process (lower) the
queue.
The dispatch function examines the current position in the queue, and
invokes the appropriate lowering function. Each lowering function
receives the queue as well, and once it needs to generate its body, it
either invokes the dispatch function on the rest of the queue (if any),
or processes nested evaluations if the work queue is at the end.
Added:
flang/lib/Lower/OpenMP/Decomposer.cpp
flang/lib/Lower/OpenMP/Decomposer.h
llvm/include/llvm/Frontend/OpenMP/ConstructCompositionT.h
llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h
llvm/unittests/Frontend/OpenMPDecompositionTest.cpp
Modified:
flang/lib/Lower/CMakeLists.txt
flang/lib/Lower/OpenMP/Clauses.cpp
flang/lib/Lower/OpenMP/Clauses.h
flang/lib/Lower/OpenMP/OpenMP.cpp
flang/lib/Lower/OpenMP/Utils.cpp
flang/lib/Lower/OpenMP/Utils.h
flang/test/Lower/OpenMP/default-clause-byref.f90
flang/test/Lower/OpenMP/default-clause.f90
flang/test/Lower/OpenMP/parallel-lastprivate-clause-scalar.f90
llvm/include/llvm/Frontend/OpenMP/ClauseT.h
llvm/unittests/Frontend/CMakeLists.txt
Removed:
################################################################################
diff --git a/flang/lib/Lower/CMakeLists.txt b/flang/lib/Lower/CMakeLists.txt
index f92d1a2bc7de1..1546409752e78 100644
--- a/flang/lib/Lower/CMakeLists.txt
+++ b/flang/lib/Lower/CMakeLists.txt
@@ -27,6 +27,7 @@ add_flang_library(FortranLower
OpenMP/ClauseProcessor.cpp
OpenMP/Clauses.cpp
OpenMP/DataSharingProcessor.cpp
+ OpenMP/Decomposer.cpp
OpenMP/OpenMP.cpp
OpenMP/ReductionProcessor.cpp
OpenMP/Utils.cpp
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index 97337cfc08c72..87370c92964a5 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1227,4 +1227,27 @@ List<Clause> makeClauses(const parser::OmpClauseList &clauses,
return makeClause(s, semaCtx);
});
}
+
+// Fill in missing source locations in `to` from clauses in `from`.
+// A clause in `to` that already has a location is left unchanged. Otherwise
+// it takes the location of the first clause in `from` with the same clause id.
+// Returns false if at least one location-less clause in `to` has no clause
+// with the same id in `from`, and true otherwise.
+bool transferLocations(const List<Clause> &from, List<Clause> &to) {
+  bool everyClauseMatched = true;
+
+  for (Clause &target : to) {
+    if (target.source.empty()) {
+      // Matching on the clause id alone is approximate, but good enough for
+      // now. For synthesized clauses an accurate location may not exist.
+      const auto match = llvm::find_if(
+          from, [&target](const Clause &c) { return c.id == target.id; });
+      if (match != from.end())
+        target.source = match->source;
+      else
+        everyClauseMatched = false;
+    }
+  }
+
+  return everyClauseMatched;
+}
+
} // namespace Fortran::lower::omp
diff --git a/flang/lib/Lower/OpenMP/Clauses.h b/flang/lib/Lower/OpenMP/Clauses.h
index 3e776425c733e..407579319279e 100644
--- a/flang/lib/Lower/OpenMP/Clauses.h
+++ b/flang/lib/Lower/OpenMP/Clauses.h
@@ -23,11 +23,15 @@
namespace Fortran::lower::omp {
using namespace Fortran;
-using SomeType = evaluate::SomeType;
using SomeExpr = semantics::SomeExpr;
using MaybeExpr = semantics::MaybeExpr;
-using TypeTy = SomeType;
+// evaluate::SomeType provides no equality operator. Type operands are not
+// meaningfully used by flang's clauses yet, so any two TypeTy values are
+// considered equal.
+struct TypeTy : evaluate::SomeType {
+  bool operator==(const TypeTy &) const { return true; }
+};
+
using IdTy = semantics::Symbol *;
using ExprTy = SomeExpr;
@@ -222,6 +226,8 @@ using When = tomp::clause::WhenT<TypeTy, IdTy, ExprTy>;
using Write = tomp::clause::WriteT<TypeTy, IdTy, ExprTy>;
} // namespace clause
+using tomp::type::operator==;
+
struct CancellationConstructType {
using EmptyTrait = std::true_type;
};
@@ -244,6 +250,7 @@ using ClauseBase = tomp::ClauseT<TypeTy, IdTy, ExprTy,
MemoryOrder, Threadprivate>;
struct Clause : public ClauseBase {
+ // "source" will be ignored by tomp::type::operator==.
parser::CharBlock source;
};
@@ -258,6 +265,8 @@ Clause makeClause(const Fortran::parser::OmpClause &cls,
List<Clause> makeClauses(const parser::OmpClauseList &clauses,
semantics::SemanticsContext &semaCtx);
+
+bool transferLocations(const List<Clause> &from, List<Clause> &to);
} // namespace Fortran::lower::omp
#endif // FORTRAN_LOWER_OPENMP_CLAUSES_H
diff --git a/flang/lib/Lower/OpenMP/Decomposer.cpp b/flang/lib/Lower/OpenMP/Decomposer.cpp
new file mode 100644
index 0000000000000..e6897cb81e947
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/Decomposer.cpp
@@ -0,0 +1,126 @@
+//===-- Decomposer.cpp -- Compound directive decomposition ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "Decomposer.h"
+
+#include "Clauses.h"
+#include "Utils.h"
+#include "flang/Lower/PFTBuilder.h"
+#include "flang/Semantics/semantics.h"
+#include "flang/Tools/CrossToolHelpers.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Frontend/OpenMP/ClauseT.h"
+#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
+#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <optional>
+#include <utility>
+#include <variant>
+
+using namespace Fortran;
+
+namespace {
+using namespace Fortran::lower::omp;
+
+// Adapter between flang lowering and tomp::ConstructDecompositionT.
+// The constructor runs the decomposition of `compound` with `clauses` and
+// keeps the resulting leaf constructs in `output`. `*this` is handed to the
+// decomposer as its helper object; presumably the decomposer queries
+// getBaseObject/getLoopIterVar through it (TODO: confirm against
+// ConstructDecompositionT.h before renaming anything here).
+struct ConstructDecomposition {
+  ConstructDecomposition(mlir::ModuleOp moduleOp,
+                         semantics::SemanticsContext &context,
+                         lower::pft::Evaluation &evaluation,
+                         llvm::omp::Directive compound,
+                         const List<Clause> &clauses)
+      : semaCtx(context), mod(moduleOp), eval(evaluation) {
+    tomp::ConstructDecompositionT decomposer(
+        getOpenMPVersionAttribute(moduleOp), *this, compound,
+        llvm::ArrayRef(clauses));
+    output = std::move(decomposer.output);
+  }
+
+  // Return the base object of `object`, or std::nullopt if it has none.
+  std::optional<Object> getBaseObject(const Object &object) {
+    return lower::omp::getBaseObject(object, semaCtx);
+  }
+
+  // Return the iteration variable of the loop associated with `eval`, or
+  // std::nullopt if there is none.
+  std::optional<Object> getLoopIterVar() {
+    semantics::Symbol *iterVar = getIterationVariableSymbol(eval);
+    if (iterVar == nullptr)
+      return std::nullopt;
+    return Object{iterVar, /*designator=*/{}};
+  }
+
+  semantics::SemanticsContext &semaCtx;
+  mlir::ModuleOp mod;
+  lower::pft::Evaluation &eval;
+  List<UnitConstruct> output;
+};
+} // namespace
+
+// Combine the consecutive leaf constructs in `units` into a single unit
+// construct, using the composition rules for OpenMP version `version`.
+static UnitConstruct mergeConstructs(uint32_t version,
+                                     llvm::ArrayRef<UnitConstruct> units) {
+  return tomp::ConstructCompositionT(version, units).merged;
+}
+
+namespace Fortran::lower::omp {
+// Debug printer: writes the directive name, then the name of each clause.
+// A tab separates the directive from the first clause; single spaces
+// separate subsequent clauses.
+LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                                               const UnitConstruct &uc) {
+  os << llvm::omp::getOpenMPDirectiveName(uc.id);
+  char separator = '\t';
+  for (const auto &clause : uc.clauses) {
+    os << separator << llvm::omp::getOpenMPClauseName(clause.id);
+    separator = ' ';
+  }
+  return os;
+}
+
+// Decompose `compound` into its leaf constructs, then regroup those leafs
+// into the lowering units reported by getLeafOrCompositeConstructs.
+// Each unit's leafs are merged into one UnitConstruct.
+// Every clause in the result that lacks a source location gets one:
+// - copied from a clause with the same id in `clauses` when one exists,
+// - otherwise the directive's own `source`.
+ConstructQueue buildConstructQueue(
+    mlir::ModuleOp modOp, Fortran::semantics::SemanticsContext &semaCtx,
+    Fortran::lower::pft::Evaluation &eval, const parser::CharBlock &source,
+    llvm::omp::Directive compound, const List<Clause> &clauses) {
+  List<UnitConstruct> constructs;
+
+  ConstructDecomposition decompose(modOp, semaCtx, eval, compound, clauses);
+  assert(!decompose.output.empty() && "Construct decomposition failed");
+  llvm::ArrayRef<UnitConstruct> leafs(decompose.output);
+
+  llvm::SmallVector<llvm::omp::Directive> loweringUnits;
+  std::ignore =
+      llvm::omp::getLeafOrCompositeConstructs(compound, loweringUnits);
+  uint32_t version = getOpenMPVersionAttribute(modOp);
+  constructs.reserve(loweringUnits.size());
+
+  size_t leafIndex = 0;
+  for (llvm::omp::Directive dirId : loweringUnits) {
+    size_t numLeafs = llvm::omp::getLeafConstructsOrSelf(dirId).size();
+    // Each lowering unit consumes the next `numLeafs` decomposed leafs.
+    // Never read past the end of the decomposition output; slice() also
+    // asserts on an out-of-range request.
+    assert(leafIndex + numLeafs <= leafs.size() &&
+           "Decomposition produced fewer leaf constructs than expected");
+
+    auto &uc = constructs.emplace_back(
+        mergeConstructs(version, leafs.slice(leafIndex, numLeafs)));
+
+    if (!transferLocations(clauses, uc.clauses)) {
+      // Some clauses (e.g. synthesized ones) have no same-id counterpart in
+      // `clauses`. Fall back to the directive's source location for them.
+      for (auto &clause : uc.clauses) {
+        if (clause.source.empty())
+          clause.source = source;
+      }
+    }
+    leafIndex += numLeafs;
+  }
+
+  return constructs;
+}
+} // namespace Fortran::lower::omp
diff --git a/flang/lib/Lower/OpenMP/Decomposer.h b/flang/lib/Lower/OpenMP/Decomposer.h
new file mode 100644
index 0000000000000..f42d8f5c17408
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/Decomposer.h
@@ -0,0 +1,51 @@
+//===-- Decomposer.h -- Compound directive decomposition ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef FORTRAN_LOWER_OPENMP_DECOMPOSER_H
+#define FORTRAN_LOWER_OPENMP_DECOMPOSER_H
+
+#include "Clauses.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
+#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+#include "llvm/Support/Compiler.h"
+
+namespace llvm {
+class raw_ostream;
+}
+
+namespace Fortran {
+namespace semantics {
+class SemanticsContext;
+}
+namespace lower::pft {
+struct Evaluation;
+}
+} // namespace Fortran
+
+namespace Fortran::lower::omp {
+// A single leaf or composite construct together with the clauses that
+// apply to it.
+using UnitConstruct = tomp::DirectiveWithClauses<lower::omp::Clause>;
+// Work queue of constructs produced by buildConstructQueue. While lowering
+// an item, the lowering code dispatches the next item (if any) from inside
+// that item's body. Later items are therefore nested in earlier ones.
+using ConstructQueue = List<UnitConstruct>;
+
+// Debug printer: writes the directive name followed by the names of its
+// clauses.
+LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const UnitConstruct &uc);
+
+// Given a potentially compound construct with a list of clauses that
+// apply to it, break it up into individual sub-constructs each with
+// the subset of applicable clauses (plus implicit clauses, if any).
+// From that create a work queue where each work item corresponds to
+// the sub-construct with its clauses.
+// `source` is the directive's location. It is assigned to any clause in the
+// queue whose location could not be copied from a same-id clause in `clauses`.
+ConstructQueue buildConstructQueue(mlir::ModuleOp modOp,
+ semantics::SemanticsContext &semaCtx,
+ lower::pft::Evaluation &eval,
+ const parser::CharBlock &source,
+ llvm::omp::Directive compound,
+ const List<Clause> &clauses);
+} // namespace Fortran::lower::omp
+
+#endif // FORTRAN_LOWER_OPENMP_DECOMPOSER_H
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f23902d6a8239..eaf4b5f997ff7 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -15,6 +15,7 @@
#include "ClauseProcessor.h"
#include "Clauses.h"
#include "DataSharingProcessor.h"
+#include "Decomposer.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
@@ -44,6 +45,13 @@ using namespace Fortran::lower::omp;
// Code generation helper functions
//===----------------------------------------------------------------------===//
+static void genOMPDispatch(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ mlir::Location loc, const ConstructQueue &queue,
+ ConstructQueue::iterator item);
+
static Fortran::lower::pft::Evaluation *
getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) {
// Return the Evaluation of the innermost collapsed loop, or the current one
@@ -460,81 +468,6 @@ markDeclareTarget(mlir::Operation *op,
declareTargetOp.setDeclareTarget(deviceType, captureClause);
}
-/// Split a combined directive into an outer leaf directive and the (possibly
-/// combined) rest of the combined directive. Composite directives and
-/// non-compound directives are not split, in which case it will return the
-/// input directive as its first output and an empty value as its second output.
-static std::pair<llvm::omp::Directive, std::optional<llvm::omp::Directive>>
-splitCombinedDirective(llvm::omp::Directive dir) {
- using D = llvm::omp::Directive;
- switch (dir) {
- case D::OMPD_masked_taskloop:
- return {D::OMPD_masked, D::OMPD_taskloop};
- case D::OMPD_masked_taskloop_simd:
- return {D::OMPD_masked, D::OMPD_taskloop_simd};
- case D::OMPD_master_taskloop:
- return {D::OMPD_master, D::OMPD_taskloop};
- case D::OMPD_master_taskloop_simd:
- return {D::OMPD_master, D::OMPD_taskloop_simd};
- case D::OMPD_parallel_do:
- return {D::OMPD_parallel, D::OMPD_do};
- case D::OMPD_parallel_do_simd:
- return {D::OMPD_parallel, D::OMPD_do_simd};
- case D::OMPD_parallel_masked:
- return {D::OMPD_parallel, D::OMPD_masked};
- case D::OMPD_parallel_masked_taskloop:
- return {D::OMPD_parallel, D::OMPD_masked_taskloop};
- case D::OMPD_parallel_masked_taskloop_simd:
- return {D::OMPD_parallel, D::OMPD_masked_taskloop_simd};
- case D::OMPD_parallel_master:
- return {D::OMPD_parallel, D::OMPD_master};
- case D::OMPD_parallel_master_taskloop:
- return {D::OMPD_parallel, D::OMPD_master_taskloop};
- case D::OMPD_parallel_master_taskloop_simd:
- return {D::OMPD_parallel, D::OMPD_master_taskloop_simd};
- case D::OMPD_parallel_sections:
- return {D::OMPD_parallel, D::OMPD_sections};
- case D::OMPD_parallel_workshare:
- return {D::OMPD_parallel, D::OMPD_workshare};
- case D::OMPD_target_parallel:
- return {D::OMPD_target, D::OMPD_parallel};
- case D::OMPD_target_parallel_do:
- return {D::OMPD_target, D::OMPD_parallel_do};
- case D::OMPD_target_parallel_do_simd:
- return {D::OMPD_target, D::OMPD_parallel_do_simd};
- case D::OMPD_target_simd:
- return {D::OMPD_target, D::OMPD_simd};
- case D::OMPD_target_teams:
- return {D::OMPD_target, D::OMPD_teams};
- case D::OMPD_target_teams_distribute:
- return {D::OMPD_target, D::OMPD_teams_distribute};
- case D::OMPD_target_teams_distribute_parallel_do:
- return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do};
- case D::OMPD_target_teams_distribute_parallel_do_simd:
- return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do_simd};
- case D::OMPD_target_teams_distribute_simd:
- return {D::OMPD_target, D::OMPD_teams_distribute_simd};
- case D::OMPD_teams_distribute:
- return {D::OMPD_teams, D::OMPD_distribute};
- case D::OMPD_teams_distribute_parallel_do:
- return {D::OMPD_teams, D::OMPD_distribute_parallel_do};
- case D::OMPD_teams_distribute_parallel_do_simd:
- return {D::OMPD_teams, D::OMPD_distribute_parallel_do_simd};
- case D::OMPD_teams_distribute_simd:
- return {D::OMPD_teams, D::OMPD_distribute_simd};
- case D::OMPD_parallel_loop:
- return {D::OMPD_parallel, D::OMPD_loop};
- case D::OMPD_target_parallel_loop:
- return {D::OMPD_target, D::OMPD_parallel_loop};
- case D::OMPD_target_teams_loop:
- return {D::OMPD_target, D::OMPD_teams_loop};
- case D::OMPD_teams_loop:
- return {D::OMPD_teams, D::OMPD_loop};
- default:
- return {dir, std::nullopt};
- }
-}
-
//===----------------------------------------------------------------------===//
// Op body generation helper structures and functions
//===----------------------------------------------------------------------===//
@@ -555,11 +488,6 @@ struct OpWithBodyGenInfo {
: converter(converter), symTable(symTable), semaCtx(semaCtx), loc(loc),
eval(eval), dir(dir) {}
- OpWithBodyGenInfo &setGenNested(bool value) {
- genNested = value;
- return *this;
- }
-
OpWithBodyGenInfo &setOuterCombined(bool value) {
outerCombined = value;
return *this;
@@ -600,8 +528,6 @@ struct OpWithBodyGenInfo {
Fortran::lower::pft::Evaluation &eval;
/// [in] leaf directive for which to generate the op body.
llvm::omp::Directive dir;
- /// [in] whether to generate FIR for nested evaluations
- bool genNested = true;
/// [in] is this an outer operation - prevents privatization.
bool outerCombined = false;
/// [in] list of clauses to process.
@@ -620,9 +546,13 @@ struct OpWithBodyGenInfo {
/// Create the body (block) for an OpenMP Operation.
///
-/// \param [in] op - the operation the body belongs to.
-/// \param [in] info - options controlling code-gen for the construction.
-static void createBodyOfOp(mlir::Operation &op, OpWithBodyGenInfo &info) {
+/// \param [in] op - the operation the body belongs to.
+/// \param [in] info - options controlling code-gen for the construction.
+/// \param [in] queue - work queue with nested constructs.
+/// \param [in] item - item in the queue to generate body for.
+static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
+ const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
fir::FirOpBuilder &firOpBuilder = info.converter.getFirOpBuilder();
auto insertMarker = [](fir::FirOpBuilder &builder) {
@@ -678,7 +608,10 @@ static void createBodyOfOp(mlir::Operation &op, OpWithBodyGenInfo &info) {
}
}
- if (info.genNested) {
+ if (ConstructQueue::iterator next = std::next(item); next != queue.end()) {
+ genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
+ info.loc, queue, next);
+ } else {
// genFIR(Evaluation&) tries to patch up unterminated blocks, causing
// a lot of complications for our approach if the terminator generation
// is delayed past this point. Insert a temporary terminator here, then
@@ -769,11 +702,12 @@ static void genBodyOfTargetDataOp(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::omp::TargetDataOp &dataOp, llvm::ArrayRef<mlir::Type> useDeviceTypes,
+ Fortran::lower::pft::Evaluation &eval, mlir::omp::TargetDataOp &dataOp,
+ llvm::ArrayRef<mlir::Type> useDeviceTypes,
llvm::ArrayRef<mlir::Location> useDeviceLocs,
llvm::ArrayRef<const Fortran::semantics::Symbol *> useDeviceSymbols,
- const mlir::Location &currentLocation) {
+ const mlir::Location &currentLocation, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Region &region = dataOp.getRegion();
@@ -826,8 +760,13 @@ static void genBodyOfTargetDataOp(
// Set the insertion point after the marker.
firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
- if (genNested)
+
+ if (ConstructQueue::iterator next = std::next(item); next != queue.end()) {
+ genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
+ next);
+ } else {
genNestedEvaluations(converter, eval);
+ }
}
// This functions creates a block for the body of the targetOp's region. It adds
@@ -836,12 +775,13 @@ static void
genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
+ Fortran::lower::pft::Evaluation &eval,
mlir::omp::TargetOp &targetOp,
llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms,
llvm::ArrayRef<mlir::Location> mapSymLocs,
llvm::ArrayRef<mlir::Type> mapSymTypes,
- const mlir::Location &currentLocation) {
+ const mlir::Location &currentLocation,
+ const ConstructQueue &queue, ConstructQueue::iterator item) {
assert(mapSymTypes.size() == mapSymLocs.size());
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -983,15 +923,22 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
// Create the insertion point after the marker.
firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
- if (genNested)
+
+ if (ConstructQueue::iterator next = std::next(item); next != queue.end()) {
+ genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
+ next);
+ } else {
genNestedEvaluations(converter, eval);
+ }
}
template <typename OpTy, typename... Args>
-static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) {
+static OpTy genOpWithBody(const OpWithBodyGenInfo &info,
+ const ConstructQueue &queue,
+ ConstructQueue::iterator item, Args &&...args) {
auto op = info.converter.getFirOpBuilder().create<OpTy>(
info.loc, std::forward<Args>(args)...);
- createBodyOfOp(*op, info);
+ createBodyOfOp(*op, info, queue, item);
return op;
}
@@ -1276,7 +1223,8 @@ static mlir::omp::BarrierOp
genBarrierOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, mlir::Location loc) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const ConstructQueue &queue, ConstructQueue::iterator item) {
return converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(loc);
}
@@ -1284,8 +1232,9 @@ static mlir::omp::CriticalOp
genCriticalOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const List<Clause> &clauses,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item,
const std::optional<Fortran::parser::Name> &name) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::FlatSymbolRefAttr nameAttr;
@@ -1308,17 +1257,17 @@ genCriticalOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::CriticalOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
- llvm::omp::Directive::OMPD_critical)
- .setGenNested(genNested),
- nameAttr);
+ llvm::omp::Directive::OMPD_critical),
+ queue, item, nameAttr);
}
static mlir::omp::DistributeOp
genDistributeOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const List<Clause> &clauses) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
TODO(loc, "Distribute construct");
return nullptr;
}
@@ -1328,7 +1277,8 @@ genFlushOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const ObjectList &objects, const List<Clause> &clauses) {
+ const ObjectList &objects, const List<Clause> &clauses,
+ const ConstructQueue &queue, ConstructQueue::iterator item) {
llvm::SmallVector<mlir::Value> operandRange;
genFlushClauses(converter, semaCtx, objects, clauses, loc, operandRange);
@@ -1340,12 +1290,13 @@ static mlir::omp::MasterOp
genMasterOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
return genOpWithBody<mlir::omp::MasterOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
- llvm::omp::Directive::OMPD_master)
- .setGenNested(genNested));
+ llvm::omp::Directive::OMPD_master),
+ queue, item);
}
static mlir::omp::OrderedOp
@@ -1353,7 +1304,8 @@ genOrderedOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const List<Clause> &clauses) {
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
TODO(loc, "OMPD_ordered");
return nullptr;
}
@@ -1362,25 +1314,25 @@ static mlir::omp::OrderedRegionOp
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const List<Clause> &clauses) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
mlir::omp::OrderedRegionClauseOps clauseOps;
genOrderedRegionClauses(converter, semaCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::OrderedRegionOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
- llvm::omp::Directive::OMPD_ordered)
- .setGenNested(genNested),
- clauseOps);
+ llvm::omp::Directive::OMPD_ordered),
+ queue, item, clauseOps);
}
static mlir::omp::ParallelOp
genParallelOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const List<Clause> &clauses,
- bool outerCombined = false) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item, bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
mlir::omp::ParallelClauseOps clauseOps;
@@ -1399,14 +1351,14 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_parallel)
- .setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauses)
.setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(reductionCallback);
if (!enableDelayedPrivatization)
- return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
+ return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item,
+ clauseOps);
bool privatize = !outerCombined;
DataSharingProcessor dsp(converter, semaCtx, clauses, eval,
@@ -1454,19 +1406,23 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
};
genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
- return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
+ return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
}
static mlir::omp::SectionOp
genSectionOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
+ // Currently only private/firstprivate clause is handled, and
+ // all privatization is done within `omp.section` operations.
return genOpWithBody<mlir::omp::SectionOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_section)
- .setGenNested(genNested));
+ .setClauses(&clauses),
+ queue, item);
}
static mlir::omp::SectionsOp
@@ -1474,12 +1430,77 @@ genSectionsOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const mlir::omp::SectionsClauseOps &clauseOps) {
- return genOpWithBody<mlir::omp::SectionsOp>(
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
+ mlir::omp::SectionsClauseOps clauseOps;
+ genSectionsClauses(converter, semaCtx, clauses, loc, clauseOps);
+
+ auto &builder = converter.getFirOpBuilder();
+
+ // Insert privatizations before SECTIONS
+ symTable.pushScope();
+ DataSharingProcessor dsp(converter, semaCtx, clauses, eval);
+ dsp.processStep1();
+
+ List<Clause> nonDsaClauses;
+ List<const clause::Lastprivate *> lastprivates;
+
+ for (const Clause &clause : clauses) {
+ if (clause.id == llvm::omp::Clause::OMPC_lastprivate) {
+ lastprivates.push_back(&std::get<clause::Lastprivate>(clause.u));
+ } else {
+ switch (clause.id) {
+ case llvm::omp::Clause::OMPC_firstprivate:
+ case llvm::omp::Clause::OMPC_private:
+ case llvm::omp::Clause::OMPC_shared:
+ break;
+ default:
+ nonDsaClauses.push_back(clause);
+ }
+ }
+ }
+
+ // SECTIONS construct.
+ mlir::omp::SectionsOp sectionsOp = genOpWithBody<mlir::omp::SectionsOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_sections)
- .setGenNested(false),
- clauseOps);
+ .setClauses(&nonDsaClauses),
+ queue, item, clauseOps);
+
+ if (!lastprivates.empty()) {
+ mlir::Region &sectionsBody = sectionsOp.getRegion();
+ assert(sectionsBody.hasOneBlock());
+ mlir::Block &body = sectionsBody.front();
+
+ auto lastSectionOp = llvm::find_if(
+ llvm::reverse(body.getOperations()), [](const mlir::Operation &op) {
+ return llvm::isa<mlir::omp::SectionOp>(op);
+ });
+ assert(lastSectionOp != body.rend());
+
+ for (const clause::Lastprivate *lastp : lastprivates) {
+ builder.setInsertionPoint(
+ lastSectionOp->getRegion(0).back().getTerminator());
+ mlir::OpBuilder::InsertPoint insp = builder.saveInsertionPoint();
+ const auto &objList = std::get<ObjectList>(lastp->t);
+ for (const Object &object : objList) {
+ Fortran::semantics::Symbol *sym = object.id();
+ converter.copyHostAssociateVar(*sym, &insp);
+ }
+ }
+ }
+
+ // Perform DataSharingProcessor's step2 out of SECTIONS
+ builder.setInsertionPointAfter(sectionsOp.getOperation());
+ dsp.processStep2(sectionsOp, false);
+ // Emit implicit barrier to synchronize threads and avoid data
+ // races on post-update of lastprivate variables when `nowait`
+ // clause is present.
+ if (clauseOps.nowaitAttr && !lastprivates.empty())
+ builder.create<mlir::omp::BarrierOp>(loc);
+
+ symTable.popScope();
+ return sectionsOp;
}
static mlir::omp::SimdOp
@@ -1487,7 +1508,8 @@ genSimdOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const List<Clause> &clauses) {
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
DataSharingProcessor dsp(converter, semaCtx, clauses, eval);
dsp.processStep1();
@@ -1522,7 +1544,8 @@ genSimdOp(Fortran::lower::AbstractConverter &converter,
*nestedEval, llvm::omp::Directive::OMPD_simd)
.setClauses(&clauses)
.setDataSharingProcessor(&dsp)
- .setGenRegionEntryCb(ivCallback));
+ .setGenRegionEntryCb(ivCallback),
+ queue, item);
return simdOp;
}
@@ -1531,26 +1554,26 @@ static mlir::omp::SingleOp
genSingleOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const List<Clause> &clauses) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
mlir::omp::SingleClauseOps clauseOps;
genSingleClauses(converter, semaCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::SingleOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_single)
- .setGenNested(genNested)
.setClauses(&clauses),
- clauseOps);
+ queue, item, clauseOps);
}
static mlir::omp::TargetOp
genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const List<Clause> &clauses,
- bool outerCombined = false) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item, bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1657,8 +1680,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(loc, clauseOps);
- genBodyOfTargetOp(converter, symTable, semaCtx, eval, genNested, targetOp,
- mapSyms, mapLocs, mapTypes, loc);
+ genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, mapSyms,
+ mapLocs, mapTypes, loc, queue, item);
return targetOp;
}
@@ -1666,8 +1689,9 @@ static mlir::omp::TargetDataOp
genTargetDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const List<Clause> &clauses) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TargetDataClauseOps clauseOps;
llvm::SmallVector<mlir::Type> useDeviceTypes;
@@ -1679,9 +1703,9 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
auto targetDataOp =
converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>(loc,
clauseOps);
- genBodyOfTargetDataOp(converter, symTable, semaCtx, eval, genNested,
- targetDataOp, useDeviceTypes, useDeviceLocs,
- useDeviceSyms, loc);
+ genBodyOfTargetDataOp(converter, symTable, semaCtx, eval, targetDataOp,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms, loc,
+ queue, item);
return targetDataOp;
}
@@ -1690,8 +1714,9 @@ static OpTy
genTargetEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- mlir::Location loc,
- const List<Clause> &clauses) {
+ mlir::Location loc, const List<Clause> &clauses,
+ const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1718,8 +1743,9 @@ static mlir::omp::TaskOp
genTaskOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const List<Clause> &clauses) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TaskClauseOps clauseOps;
genTaskClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps);
@@ -1727,26 +1753,25 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::TaskOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_task)
- .setGenNested(genNested)
.setClauses(&clauses),
- clauseOps);
+ queue, item, clauseOps);
}
static mlir::omp::TaskgroupOp
genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const List<Clause> &clauses) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
mlir::omp::TaskgroupClauseOps clauseOps;
genTaskgroupClauses(converter, semaCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_taskgroup)
- .setGenNested(genNested)
.setClauses(&clauses),
- clauseOps);
+ queue, item, clauseOps);
}
static mlir::omp::TaskloopOp
@@ -1754,7 +1779,8 @@ genTaskloopOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const List<Clause> &clauses) {
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
TODO(loc, "Taskloop construct");
}
@@ -1763,7 +1789,8 @@ genTaskwaitOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const List<Clause> &clauses) {
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
mlir::omp::TaskwaitClauseOps clauseOps;
genTaskwaitClauses(converter, semaCtx, clauses, loc, clauseOps);
return converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(loc,
@@ -1774,7 +1801,8 @@ static mlir::omp::TaskyieldOp
genTaskyieldOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, mlir::Location loc) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const ConstructQueue &queue, ConstructQueue::iterator item) {
return converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(loc);
}
@@ -1782,9 +1810,9 @@ static mlir::omp::TeamsOp
genTeamsOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const List<Clause> &clauses,
- bool outerCombined = false) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item, bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TeamsClauseOps clauseOps;
genTeamsClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps);
@@ -1792,10 +1820,9 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_teams)
- .setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauses),
- clauseOps);
+ queue, item, clauseOps);
}
static mlir::omp::WsloopOp
@@ -1803,7 +1830,8 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const List<Clause> &clauses) {
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
DataSharingProcessor dsp(converter, semaCtx, clauses, eval);
dsp.processStep1();
@@ -1844,7 +1872,8 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
.setClauses(&clauses)
.setDataSharingProcessor(&dsp)
.setReductions(&reductionSyms, &reductionTypes)
- .setGenRegionEntryCb(ivCallback));
+ .setGenRegionEntryCb(ivCallback),
+ queue, item);
return wsloopOp;
}
@@ -1852,13 +1881,13 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
// Code generation functions for composite constructs
//===----------------------------------------------------------------------===//
-static void
-genCompositeDistributeParallelDo(Fortran::lower::AbstractConverter &converter,
- Fortran::lower::SymMap &symTable,
- Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval,
- const List<Clause> &clauses,
- mlir::Location loc) {
+static void genCompositeDistributeParallelDo(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
}
@@ -1866,8 +1895,9 @@ static void genCompositeDistributeParallelDoSimd(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::pft::Evaluation &eval, const List<Clause> &clauses,
- mlir::Location loc) {
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const List<Clause> &clauses, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
}
@@ -1876,7 +1906,9 @@ genCompositeDistributeSimd(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const List<Clause> &clauses, mlir::Location loc) {
+ mlir::Location loc, const List<Clause> &clauses,
+ const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
TODO(loc, "Composite DISTRIBUTE SIMD");
}
@@ -1884,8 +1916,9 @@ static void genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const List<Clause> &clauses,
- mlir::Location loc) {
+ mlir::Location loc, const List<Clause> &clauses,
+ const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
clause::Order, clause::Safelen, clause::Simdlen>(
@@ -1898,7 +1931,7 @@ static void genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
// When support for vectorization is enabled, then we need to add handling of
// if clause. Currently if clause can be skipped because we always assume
// SIMD length = 1.
- genWsloopOp(converter, symTable, semaCtx, eval, loc, clauses);
+ genWsloopOp(converter, symTable, semaCtx, eval, loc, clauses, queue, item);
}
static void
@@ -1906,10 +1939,128 @@ genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const List<Clause> &clauses, mlir::Location loc) {
+ mlir::Location loc, const List<Clause> &clauses,
+ const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
TODO(loc, "Composite TASKLOOP SIMD");
}
+//===----------------------------------------------------------------------===//
+// Dispatch
+//===----------------------------------------------------------------------===//
+
+/// Lower the construct at position \p item of \p queue.
+///
+/// Selects the gen* lowering function for the directive of \p item and passes
+/// it the clauses attached to that queue entry, together with \p queue and
+/// \p item, so that the callee can continue with the rest of the queue when it
+/// generates its body. \p item must not be queue.end().
+///
+/// A directive without a case below is treated as an internal error, matching
+/// the per-construct switches (block, loop and sections) that this dispatch
+/// replaces.
+static void genOMPDispatch(Fortran::lower::AbstractConverter &converter,
+                           Fortran::lower::SymMap &symTable,
+                           Fortran::semantics::SemanticsContext &semaCtx,
+                           Fortran::lower::pft::Evaluation &eval,
+                           mlir::Location loc, const ConstructQueue &queue,
+                           ConstructQueue::iterator item) {
+  assert(item != queue.end());
+  const List<Clause> &clauses = item->clauses;
+
+  switch (llvm::omp::Directive dir = item->id) {
+  case llvm::omp::Directive::OMPD_distribute:
+    genDistributeOp(converter, symTable, semaCtx, eval, loc, clauses, queue,
+                    item);
+    break;
+  case llvm::omp::Directive::OMPD_do:
+    genWsloopOp(converter, symTable, semaCtx, eval, loc, clauses, queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_loop:
+  case llvm::omp::Directive::OMPD_masked:
+  case llvm::omp::Directive::OMPD_tile:
+  case llvm::omp::Directive::OMPD_unroll:
+    TODO(loc, "Unhandled loop directive (" +
+                  llvm::omp::getOpenMPDirectiveName(dir) + ")");
+    break;
+  case llvm::omp::Directive::OMPD_master:
+    genMasterOp(converter, symTable, semaCtx, eval, loc, clauses, queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_ordered:
+    genOrderedRegionOp(converter, symTable, semaCtx, eval, loc, clauses, queue,
+                       item);
+    break;
+  case llvm::omp::Directive::OMPD_parallel:
+    genParallelOp(converter, symTable, semaCtx, eval, loc, clauses, queue, item,
+                  /*outerCombined=*/false);
+    break;
+  case llvm::omp::Directive::OMPD_sections:
+    genSectionsOp(converter, symTable, semaCtx, eval, loc, clauses, queue,
+                  item);
+    break;
+  case llvm::omp::Directive::OMPD_simd:
+    genSimdOp(converter, symTable, semaCtx, eval, loc, clauses, queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_single:
+    genSingleOp(converter, symTable, semaCtx, eval, loc, clauses, queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_target:
+    genTargetOp(converter, symTable, semaCtx, eval, loc, clauses, queue, item,
+                /*outerCombined=*/false);
+    break;
+  case llvm::omp::Directive::OMPD_target_data:
+    genTargetDataOp(converter, symTable, semaCtx, eval, loc, clauses, queue,
+                    item);
+    break;
+  case llvm::omp::Directive::OMPD_target_enter_data:
+    genTargetEnterExitUpdateDataOp<mlir::omp::TargetEnterDataOp>(
+        converter, symTable, semaCtx, loc, clauses, queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_target_exit_data:
+    genTargetEnterExitUpdateDataOp<mlir::omp::TargetExitDataOp>(
+        converter, symTable, semaCtx, loc, clauses, queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_target_update:
+    genTargetEnterExitUpdateDataOp<mlir::omp::TargetUpdateOp>(
+        converter, symTable, semaCtx, loc, clauses, queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_task:
+    genTaskOp(converter, symTable, semaCtx, eval, loc, clauses, queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_taskgroup:
+    genTaskgroupOp(converter, symTable, semaCtx, eval, loc, clauses, queue,
+                   item);
+    break;
+  case llvm::omp::Directive::OMPD_taskloop:
+    genTaskloopOp(converter, symTable, semaCtx, eval, loc, clauses, queue,
+                  item);
+    break;
+  case llvm::omp::Directive::OMPD_teams:
+    genTeamsOp(converter, symTable, semaCtx, eval, loc, clauses, queue, item,
+               /*outerCombined=*/false);
+    break;
+  // case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_workshare:
+    // FIXME: Workshare is not a commonly used OpenMP construct, an
+    // implementation for this feature will come later. For the codes
+    // that use this construct, add a single construct for now.
+    genSingleOp(converter, symTable, semaCtx, eval, loc, clauses, queue, item);
+    break;
+  // Composite constructs
+  case llvm::omp::Directive::OMPD_distribute_parallel_do:
+    genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
+                                     clauses, queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
+    genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
+                                         loc, clauses, queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_distribute_simd:
+    genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, clauses,
+                               queue, item);
+    break;
+  case llvm::omp::Directive::OMPD_do_simd:
+    genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, clauses, queue,
+                       item);
+    break;
+  case llvm::omp::Directive::OMPD_taskloop_simd:
+    genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, clauses,
+                             queue, item);
+    break;
+  default:
+    // Every construct handled by the block, loop and sections switches that
+    // predate this dispatch has a case above. Breaking out here instead would
+    // silently emit no operation for the construct.
+    llvm_unreachable("Unexpected directive in construct queue");
+  }
+}
+
//===----------------------------------------------------------------------===//
// OpenMPDeclarativeConstruct visitors
//===----------------------------------------------------------------------===//
@@ -2020,36 +2171,47 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
semaCtx);
mlir::Location currentLocation = converter.genLocation(directive.source);
+ ConstructQueue queue{
+ buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
+ eval, directive.source, directive.v, clauses)};
+
switch (directive.v) {
default:
break;
case llvm::omp::Directive::OMPD_barrier:
- genBarrierOp(converter, symTable, semaCtx, eval, currentLocation);
+ genBarrierOp(converter, symTable, semaCtx, eval, currentLocation, queue,
+ queue.begin());
break;
case llvm::omp::Directive::OMPD_taskwait:
- genTaskwaitOp(converter, symTable, semaCtx, eval, currentLocation, clauses);
+ genTaskwaitOp(converter, symTable, semaCtx, eval, currentLocation, clauses,
+ queue, queue.begin());
break;
case llvm::omp::Directive::OMPD_taskyield:
- genTaskyieldOp(converter, symTable, semaCtx, eval, currentLocation);
+ genTaskyieldOp(converter, symTable, semaCtx, eval, currentLocation, queue,
+ queue.begin());
break;
case llvm::omp::Directive::OMPD_target_data:
- genTargetDataOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
- currentLocation, clauses);
+ genTargetDataOp(converter, symTable, semaCtx, eval, currentLocation,
+ clauses, queue, queue.begin());
break;
case llvm::omp::Directive::OMPD_target_enter_data:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetEnterDataOp>(
- converter, symTable, semaCtx, currentLocation, clauses);
+ converter, symTable, semaCtx, currentLocation, clauses, queue,
+ queue.begin());
break;
case llvm::omp::Directive::OMPD_target_exit_data:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetExitDataOp>(
- converter, symTable, semaCtx, currentLocation, clauses);
+ converter, symTable, semaCtx, currentLocation, clauses, queue,
+ queue.begin());
break;
case llvm::omp::Directive::OMPD_target_update:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetUpdateOp>(
- converter, symTable, semaCtx, currentLocation, clauses);
+ converter, symTable, semaCtx, currentLocation, clauses, queue,
+ queue.begin());
break;
case llvm::omp::Directive::OMPD_ordered:
- genOrderedOp(converter, symTable, semaCtx, eval, currentLocation, clauses);
+ genOrderedOp(converter, symTable, semaCtx, eval, currentLocation, clauses,
+ queue, queue.begin());
break;
}
}
@@ -2073,8 +2235,12 @@ genOMP(Fortran::lower::AbstractConverter &converter,
[&](auto &&s) { return makeClause(s.v, semaCtx); })
: List<Clause>{};
mlir::Location currentLocation = converter.genLocation(verbatim.source);
+
+ ConstructQueue queue{buildConstructQueue(
+ converter.getFirOpBuilder().getModule(), semaCtx, eval, verbatim.source,
+ llvm::omp::Directive::OMPD_flush, clauses)};
genFlushOp(converter, symTable, semaCtx, eval, currentLocation, objects,
- clauses);
+ clauses, queue, queue.begin());
}
static void
@@ -2217,75 +2383,15 @@ genOMP(Fortran::lower::AbstractConverter &converter,
}
}
- std::optional<llvm::omp::Directive> nextDir = origDirective;
- bool outermostLeafConstruct = true;
- while (nextDir) {
- llvm::omp::Directive leafDir;
- std::tie(leafDir, nextDir) = splitCombinedDirective(*nextDir);
- const bool genNested = !nextDir;
- const bool outerCombined = outermostLeafConstruct && nextDir.has_value();
- switch (leafDir) {
- case llvm::omp::Directive::OMPD_master:
- // 2.16 MASTER construct.
- genMasterOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation);
- break;
- case llvm::omp::Directive::OMPD_ordered:
- // 2.17.9 ORDERED construct.
- genOrderedRegionOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses);
- break;
- case llvm::omp::Directive::OMPD_parallel:
- // 2.6 PARALLEL construct.
- genParallelOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses, outerCombined);
- break;
- case llvm::omp::Directive::OMPD_single:
- // 2.8.2 SINGLE construct.
- genSingleOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses);
- break;
- case llvm::omp::Directive::OMPD_target:
- // 2.12.5 TARGET construct.
- genTargetOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses, outerCombined);
- break;
- case llvm::omp::Directive::OMPD_target_data:
- // 2.12.2 TARGET DATA construct.
- genTargetDataOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses);
- break;
- case llvm::omp::Directive::OMPD_task:
- // 2.10.1 TASK construct.
- genTaskOp(converter, symTable, semaCtx, eval, genNested, currentLocation,
- clauses);
- break;
- case llvm::omp::Directive::OMPD_taskgroup:
- // 2.17.6 TASKGROUP construct.
- genTaskgroupOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses);
- break;
- case llvm::omp::Directive::OMPD_teams:
- // 2.7 TEAMS construct.
- // FIXME Pass the outerCombined argument or rename it to better describe
- // what it represents if it must always be `false` in this context.
- genTeamsOp(converter, symTable, semaCtx, eval, genNested, currentLocation,
- clauses);
- break;
- case llvm::omp::Directive::OMPD_workshare:
- // 2.8.3 WORKSHARE construct.
- // FIXME: Workshare is not a commonly used OpenMP construct, an
- // implementation for this feature will come later. For the codes
- // that use this construct, add a single construct for now.
- genSingleOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses);
- break;
- default:
- llvm_unreachable("Unexpected block construct");
- break;
- }
- outermostLeafConstruct = false;
- }
+ llvm::omp::Directive directive =
+ std::get<parser::OmpBlockDirective>(beginBlockDirective.t).v;
+ const parser::CharBlock &source =
+ std::get<parser::OmpBlockDirective>(beginBlockDirective.t).source;
+ ConstructQueue queue{
+ buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
+ eval, source, directive, clauses)};
+ genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
+ queue.begin());
}
static void
@@ -2298,10 +2404,15 @@ genOMP(Fortran::lower::AbstractConverter &converter,
std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
List<Clause> clauses =
makeClauses(std::get<Fortran::parser::OmpClauseList>(cd.t), semaCtx);
+
+ ConstructQueue queue{buildConstructQueue(
+ converter.getFirOpBuilder().getModule(), semaCtx, eval, cd.source,
+ llvm::omp::Directive::OMPD_critical, clauses)};
+
const auto &name = std::get<std::optional<Fortran::parser::Name>>(cd.t);
mlir::Location currentLocation = converter.getCurrentLocation();
- genCriticalOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
- currentLocation, clauses, name);
+ genCriticalOp(converter, symTable, semaCtx, eval, currentLocation, clauses,
+ queue, queue.begin(), name);
}
static void
@@ -2322,14 +2433,6 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t);
List<Clause> clauses = makeClauses(
std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
- mlir::Location currentLocation =
- converter.genLocation(beginLoopDirective.source);
- const auto origDirective =
- std::get<Fortran::parser::OmpLoopDirective>(beginLoopDirective.t).v;
-
- assert(llvm::omp::loopConstructSet.test(origDirective) &&
- "Expected loop construct");
-
if (auto &endLoopDirective =
std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>(
loopConstruct.t)) {
@@ -2338,101 +2441,18 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
semaCtx));
}
- std::optional<llvm::omp::Directive> nextDir = origDirective;
- while (nextDir) {
- llvm::omp::Directive leafDir;
- std::tie(leafDir, nextDir) = splitCombinedDirective(*nextDir);
- if (llvm::omp::compositeConstructSet.test(leafDir)) {
- assert(!nextDir && "Composite construct cannot be split");
- switch (leafDir) {
- case llvm::omp::Directive::OMPD_distribute_parallel_do:
- // 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct.
- genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval,
- clauses, currentLocation);
- break;
- case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
- // 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct.
- genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
- clauses, currentLocation);
- break;
- case llvm::omp::Directive::OMPD_distribute_simd:
- // 2.9.4.2 DISTRIBUTE SIMD construct.
- genCompositeDistributeSimd(converter, symTable, semaCtx, eval, clauses,
- currentLocation);
- break;
- case llvm::omp::Directive::OMPD_do_simd:
- // 2.9.3.2 Worksharing-Loop SIMD construct.
- genCompositeDoSimd(converter, symTable, semaCtx, eval, clauses,
- currentLocation);
- break;
- case llvm::omp::Directive::OMPD_taskloop_simd:
- // 2.10.3 TASKLOOP SIMD construct.
- genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, clauses,
- currentLocation);
- break;
- default:
- llvm_unreachable("Unexpected composite construct");
- }
- } else {
- const bool genNested = !nextDir;
- switch (leafDir) {
- case llvm::omp::Directive::OMPD_distribute:
- // 2.9.4.1 DISTRIBUTE construct.
- genDistributeOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses);
- break;
- case llvm::omp::Directive::OMPD_do:
- // 2.9.2 Worksharing-Loop construct.
- genWsloopOp(converter, symTable, semaCtx, eval, currentLocation,
- clauses);
- break;
- case llvm::omp::Directive::OMPD_parallel:
- // 2.6 PARALLEL construct.
- // FIXME This is not necessarily always the outer leaf construct of a
- // combined construct in this constext (e.g. distribute parallel do).
- // Maybe rename the argument if it represents something else or
- // initialize it properly.
- genParallelOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses,
- /*outerCombined=*/true);
- break;
- case llvm::omp::Directive::OMPD_simd:
- // 2.9.3.1 SIMD construct.
- genSimdOp(converter, symTable, semaCtx, eval, currentLocation, clauses);
- break;
- case llvm::omp::Directive::OMPD_target:
- // 2.12.5 TARGET construct.
- genTargetOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses, /*outerCombined=*/true);
- break;
- case llvm::omp::Directive::OMPD_taskloop:
- // 2.10.2 TASKLOOP construct.
- genTaskloopOp(converter, symTable, semaCtx, eval, currentLocation,
- clauses);
- break;
- case llvm::omp::Directive::OMPD_teams:
- // 2.7 TEAMS construct.
- // FIXME This is not necessarily always the outer leaf construct of a
- // combined construct in this constext (e.g. target teams distribute).
- // Maybe rename the argument if it represents something else or
- // initialize it properly.
- genTeamsOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, clauses, /*outerCombined=*/true);
- break;
- case llvm::omp::Directive::OMPD_loop:
- case llvm::omp::Directive::OMPD_masked:
- case llvm::omp::Directive::OMPD_master:
- case llvm::omp::Directive::OMPD_tile:
- case llvm::omp::Directive::OMPD_unroll:
- TODO(currentLocation, "Unhandled loop directive (" +
- llvm::omp::getOpenMPDirectiveName(leafDir) +
- ")");
- break;
- default:
- llvm_unreachable("Unexpected loop construct");
- }
- }
- }
+ mlir::Location currentLocation =
+ converter.genLocation(beginLoopDirective.source);
+
+ llvm::omp::Directive directive =
+ std::get<parser::OmpLoopDirective>(beginLoopDirective.t).v;
+ const parser::CharBlock &source =
+ std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source;
+ ConstructQueue queue{
+ buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
+ eval, source, directive, clauses)};
+ genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
+ queue.begin());
}
static void
@@ -2441,8 +2461,12 @@ genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
- // SECTION constructs are handled as a part of SECTIONS.
- llvm_unreachable("Unexpected standalone OMP SECTION");
+ // Lower the SECTION through a construct queue built for it alone (no clauses).
+ mlir::Location loc = converter.getCurrentLocation();
+ ConstructQueue queue{buildConstructQueue(
+ converter.getFirOpBuilder().getModule(), semaCtx, eval,
+ sectionConstruct.source, llvm::omp::Directive::OMPD_section, {})};
+ genSectionOp(converter, symTable, semaCtx, eval, loc,
+ /*clauses=*/{}, queue, queue.begin());
}
static void
@@ -2461,77 +2485,17 @@ genOMP(Fortran::lower::AbstractConverter &converter,
clauses.append(makeClauses(
std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t),
semaCtx));
-
- // Process clauses before optional omp.parallel, so that new variables are
- // allocated outside of the parallel region
mlir::Location currentLocation = converter.getCurrentLocation();
- mlir::omp::SectionsClauseOps clauseOps;
- genSectionsClauses(converter, semaCtx, clauses, currentLocation, clauseOps);
-
- // Parallel wrapper of PARALLEL SECTIONS construct
- llvm::omp::Directive dir =
- std::get<Fortran::parser::OmpSectionsDirective>(beginSectionsDirective.t)
- .v;
- if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
- genParallelOp(converter, symTable, semaCtx, eval,
- /*genNested=*/false, currentLocation, clauses,
- /*outerCombined=*/true);
- }
-
- // Insert privatizations before SECTIONS
- symTable.pushScope();
- DataSharingProcessor dsp(converter, semaCtx, clauses, eval);
- dsp.processStep1();
-
- // SECTIONS construct.
- mlir::omp::SectionsOp sectionsOp = genSectionsOp(
- converter, symTable, semaCtx, eval, currentLocation, clauseOps);
-
- // Generate nested SECTION operations recursively.
- const auto §ionBlocks =
- std::get<Fortran::parser::OmpSectionBlocks>(sectionsConstruct.t);
- auto &firOpBuilder = converter.getFirOpBuilder();
- auto ip = firOpBuilder.saveInsertionPoint();
- mlir::omp::SectionOp lastSectionOp;
- for (const auto &[nblock, neval] :
- llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) {
- symTable.pushScope();
- lastSectionOp = genSectionOp(converter, symTable, semaCtx, neval,
- /*genNested=*/true, currentLocation);
- symTable.popScope();
- firOpBuilder.restoreInsertionPoint(ip);
- }
-
- // For `omp.sections`, lastprivatized variables occur in
- // lexically final `omp.section` operation.
- bool hasLastPrivate = false;
- if (lastSectionOp) {
- for (const Clause &clause : clauses) {
- if (const auto *lastPrivate =
- std::get_if<clause::Lastprivate>(&clause.u)) {
- hasLastPrivate = true;
- firOpBuilder.setInsertionPoint(
- lastSectionOp.getRegion().back().getTerminator());
- mlir::OpBuilder::InsertPoint lastPrivIP =
- converter.getFirOpBuilder().saveInsertionPoint();
- const auto &objList = std::get<1>(lastPrivate->t);
- for (const Object &obj : objList) {
- Fortran::semantics::Symbol *sym = obj.id();
- converter.copyHostAssociateVar(*sym, &lastPrivIP);
- }
- }
- }
- }
- // Perform DataSharingProcessor's step2 out of SECTIONS
- firOpBuilder.setInsertionPointAfter(sectionsOp.getOperation());
- dsp.processStep2(sectionsOp, false);
- // Emit implicit barrier to synchronize threads and avoid data
- // races on post-update of lastprivate variables when `nowait`
- // clause is present.
- if (clauseOps.nowaitAttr && hasLastPrivate)
- firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
- symTable.popScope();
+ llvm::omp::Directive directive =
+ std::get<parser::OmpSectionsDirective>(beginSectionsDirective.t).v;
+ const parser::CharBlock &source =
+ std::get<parser::OmpSectionsDirective>(beginSectionsDirective.t).source;
+ ConstructQueue queue{
+ buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
+ eval, source, directive, clauses)};
+ genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
+ queue.begin());
}
static void genOMP(Fortran::lower::AbstractConverter &converter,
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index eed63b226133a..cb1d1a5a7f3dd 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -51,12 +51,6 @@ int64_t getCollapseValue(const List<Clause> &clauses) {
return 1;
}
-uint32_t getOpenMPVersion(mlir::ModuleOp mod) {
- if (mlir::Attribute verAttr = mod->getAttr("omp.version"))
- return llvm::cast<mlir::omp::VersionAttr>(verAttr).getVersion();
- llvm_unreachable("Expecting OpenMP version attribute in module");
-}
-
void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands) {
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 8fbb18fa8656f..345ce55620ee9 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -93,7 +93,6 @@ void gatherFuncAndVarSyms(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);
int64_t getCollapseValue(const List<Clause> &clauses);
-uint32_t getOpenMPVersion(mlir::ModuleOp mod);
Fortran::semantics::Symbol *
getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject);
diff --git a/flang/test/Lower/OpenMP/default-clause-byref.f90 b/flang/test/Lower/OpenMP/default-clause-byref.f90
index 62ba67e5962f4..7cc2bc2e0c710 100644
--- a/flang/test/Lower/OpenMP/default-clause-byref.f90
+++ b/flang/test/Lower/OpenMP/default-clause-byref.f90
@@ -161,12 +161,12 @@ subroutine nested_default_clause_tests
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel {
+!CHECK: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_default_clause_testsEy"}
+!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[TEMP:.*]] = fir.load %[[X_DECL]]#0 : !fir.ref<i32>
!CHECK: hlfir.assign %[[TEMP]] to %[[PRIVATE_X_DECL]]#0 temporary_lhs : i32, !fir.ref<i32>
-!CHECK: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_default_clause_testsEy"}
-!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_K:.*]] = fir.alloca i32 {bindc_name = "k", pinned, uniq_name = "_QFnested_default_clause_testsEk"}
@@ -221,6 +221,7 @@ subroutine nested_default_clause_tests
!CHECK: omp.parallel {
+!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_default_clause_testsEy"}
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
diff --git a/flang/test/Lower/OpenMP/default-clause.f90 b/flang/test/Lower/OpenMP/default-clause.f90
index a90f0f4ef5f84..843ee6bb7910b 100644
--- a/flang/test/Lower/OpenMP/default-clause.f90
+++ b/flang/test/Lower/OpenMP/default-clause.f90
@@ -160,12 +160,12 @@ end program default_clause_lowering
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_default_clause_test1Ez"}
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFnested_default_clause_test1Ez"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel {
+!CHECK: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_default_clause_test1Ey"}
+!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_test1Ex"}
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_test1Ex"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[TEMP:.*]] = fir.load %[[X_DECL]]#0 : !fir.ref<i32>
!CHECK: hlfir.assign %[[TEMP]] to %[[PRIVATE_X_DECL]]#0 temporary_lhs : i32, !fir.ref<i32>
-!CHECK: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_default_clause_test1Ey"}
-!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_test1Ez"}
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_test1Ez"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_K:.*]] = fir.alloca i32 {bindc_name = "k", pinned, uniq_name = "_QFnested_default_clause_test1Ek"}
diff --git a/flang/test/Lower/OpenMP/parallel-lastprivate-clause-scalar.f90 b/flang/test/Lower/OpenMP/parallel-lastprivate-clause-scalar.f90
index b7f11c8c722f7..e6ee75c8a5bef 100644
--- a/flang/test/Lower/OpenMP/parallel-lastprivate-clause-scalar.f90
+++ b/flang/test/Lower/OpenMP/parallel-lastprivate-clause-scalar.f90
@@ -145,10 +145,10 @@ subroutine mult_lastprivate_int(arg1, arg2)
!CHECK: %[[ARG1_DECL:.*]]:2 = hlfir.declare %[[ARG1]] dummy_scope %{{[0-9]+}} {uniq_name = "_QFmult_lastprivate_int2Earg1"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[ARG2_DECL:.*]]:2 = hlfir.declare %[[ARG2]] dummy_scope %{{[0-9]+}} {uniq_name = "_QFmult_lastprivate_int2Earg2"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel {
-!CHECK-DAG: %[[CLONE1:.*]] = fir.alloca i32 {bindc_name = "arg1"
-!CHECK-DAG: %[[CLONE1_DECL:.*]]:2 = hlfir.declare %[[CLONE1]] {uniq_name = "_QFmult_lastprivate_int2Earg1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK-DAG: %[[CLONE2:.*]] = fir.alloca i32 {bindc_name = "arg2"
!CHECK-DAG: %[[CLONE2_DECL:.*]]:2 = hlfir.declare %[[CLONE2]] {uniq_name = "_QFmult_lastprivate_int2Earg2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK-DAG: %[[CLONE1:.*]] = fir.alloca i32 {bindc_name = "arg1"
+!CHECK-DAG: %[[CLONE1_DECL:.*]]:2 = hlfir.declare %[[CLONE1]] {uniq_name = "_QFmult_lastprivate_int2Earg1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.wsloop {
!CHECK-NEXT: omp.loop_nest (%[[INDX_WS:.*]]) : {{.*}} {
diff --git a/llvm/include/llvm/Frontend/OpenMP/ClauseT.h b/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
index daef02bcfc9a3..07c95497b7a41 100644
--- a/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
+++ b/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
@@ -178,6 +178,12 @@ template <typename T> using ListT = llvm::SmallVector<T, 0>;
// provide their own specialization that conforms to the above requirements.
template <typename IdType, typename ExprType> struct ObjectT;
+// By default, two objects compare equal exactly when their ids compare equal.
+template <typename I, typename E>
+bool operator==(const ObjectT<I, E> &lhs, const ObjectT<I, E> &rhs) {
+  return lhs.id() == rhs.id();
+}
+
template <typename I, typename E> using ObjectListT = ListT<ObjectT<I, E>>;
using DirectiveName = llvm::omp::Directive;
@@ -264,6 +270,32 @@ struct ReductionIdentifierT {
template <typename T, typename I, typename E> //
using IteratorT = ListT<IteratorSpecifierT<T, I, E>>;
+
+// Structural equality, dispatched on the trait tag each type declares.
+template <typename T>
+std::enable_if_t<T::EmptyTrait::value, bool> operator==(const T &, const T &) {
+  return true;
+}
+template <typename T>
+std::enable_if_t<T::IncompleteTrait::value, bool>
+operator==(const T &, const T &) {
+  return true;
+}
+template <typename T>
+std::enable_if_t<T::WrapperTrait::value, bool> operator==(const T &lhs,
+                                                          const T &rhs) {
+  return lhs.v == rhs.v;
+}
+template <typename T>
+std::enable_if_t<T::TupleTrait::value, bool> operator==(const T &lhs,
+                                                        const T &rhs) {
+  return lhs.t == rhs.t;
+}
+template <typename T>
+std::enable_if_t<T::UnionTrait::value, bool> operator==(const T &lhs,
+                                                        const T &rhs) {
+  return lhs.u == rhs.u;
+}
} // namespace type
template <typename T> using ListT = type::ListT<T>;
@@ -285,6 +317,8 @@ ListT<ResultTy> makeList(ContainerTy &&container, FunctionTy &&func) {
}
namespace clause {
+using type::operator==;
+
// V5.2: [8.3.1] `assumption` clauses
template <typename T, typename I, typename E> //
struct AbsentT {
@@ -726,7 +760,7 @@ struct LinearT {
ENUM(LinearModifier, Ref, Val, Uval);
using TupleTrait = std::true_type;
- // Step == nullptr means 1.
+ // Step == nullopt means 1.
std::tuple<OPT(StepSimpleModifier), OPT(StepComplexModifier),
OPT(LinearModifier), List>
t;
@@ -1142,9 +1176,11 @@ struct UsesAllocatorsT {
using MemSpace = E;
using TraitsArray = ObjectT<I, E>;
using Allocator = E;
- using AllocatorSpec =
- std::tuple<OPT(MemSpace), OPT(TraitsArray), Allocator>; // Not a spec name
- using Allocators = ListT<AllocatorSpec>; // Not a spec name
+  struct AllocatorSpec { // Not a spec name.
+    using TupleTrait = std::true_type; // Compared via TupleTrait operator==.
+    std::tuple<OPT(MemSpace), OPT(TraitsArray), Allocator> t;
+  };
+ using Allocators = ListT<AllocatorSpec>; // Not a spec name
using WrapperTrait = std::true_type;
Allocators v;
};
@@ -1232,9 +1268,10 @@ using UnionOfAllClausesT = typename type::Union< //
UnionClausesT<T, I, E>, //
WrapperClausesT<T, I, E> //
>::type;
-
} // namespace clause
+using type::operator==;
+
// The variant wrapper that encapsulates all possible specific clauses.
// The `Extras` arguments are additional types representing local extensions
// to the clause set, e.g.
@@ -1260,6 +1297,11 @@ struct ClauseT {
VariantTy u;
};
+// A directive together with the list of clauses that apply to it.
+template <typename ClauseType> struct DirectiveWithClauses {
+  llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown;
+  tomp::type::ListT<ClauseType> clauses;
+};
} // namespace tomp
#undef OPT
diff --git a/llvm/include/llvm/Frontend/OpenMP/ConstructCompositionT.h b/llvm/include/llvm/Frontend/OpenMP/ConstructCompositionT.h
new file mode 100644
index 0000000000000..7a4ed92a10703
--- /dev/null
+++ b/llvm/include/llvm/Frontend/OpenMP/ConstructCompositionT.h
@@ -0,0 +1,403 @@
+//===- ConstructCompositionT.h -- Composing compound constructs -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Given a list of leaf constructs, each with a set of clauses, generate the
+// compound construct whose leaf constructs are the given list, and whose clause
+// list is the merged lists of individual leaf clauses.
+//
+// *** At the moment it assumes that the individual constructs and their clauses
+// *** are a subset of those created by splitting a valid compound construct.
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
+#define LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Frontend/OpenMP/ClauseT.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+
+#include <iterator>
+#include <optional>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+namespace tomp {
+template <typename ClauseType> struct ConstructCompositionT {
+  using ClauseTy = ClauseType;
+
+  using TypeTy = typename ClauseTy::TypeTy;
+  using IdTy = typename ClauseTy::IdTy;
+  using ExprTy = typename ClauseTy::ExprTy;
+
+  ConstructCompositionT(uint32_t version,
+                        llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> leafs);
+
+  DirectiveWithClauses<ClauseTy> merged; // The resulting compound construct.
+
+private:
+  // Use an ordered container, since we need to maintain the order in which
+  // clauses are added to it. This is to avoid non-deterministic output.
+  using ClauseSet = ListT<ClauseTy>;
+
+  enum class Presence {
+    All,  // Clause is present on all leaf constructs that allow it.
+    Some, // Clause is present on some, but not on all constructs.
+    None, // Clause is absent on all constructs.
+  };
+
+  template <typename S>
+  ClauseTy makeClause(llvm::omp::Clause clauseId, S &&specific) {
+    return ClauseTy{clauseId, std::move(specific)};
+  }
+
+  llvm::omp::Directive
+  makeCompound(llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts);
+
+  Presence checkPresence(llvm::omp::Clause clauseId);
+
+  // Clauses that need special handling when merging:
+  // 1. "if": the "directive-name-modifier" may need to be set appropriately.
+  // 2. "reduction": implies "privateness" of all objects (incompatible with
+  //    "shared"); there are rules for merging modifiers.
+  // 3. DSA (shared/private/firstprivate/lastprivate): combined per object.
+  void mergeIf();
+  void mergeReduction();
+  void mergeDSA();
+
+  uint32_t version;
+  llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> leafs;
+
+  // clause id -> set of leaf constructs that contain it
+  std::unordered_map<llvm::omp::Clause, llvm::BitVector> clausePresence;
+  // clause id -> set of instances of that clause
+  std::unordered_map<llvm::omp::Clause, ClauseSet> clauseSets;
+};
+
+template <typename C>
+ConstructCompositionT<C>::ConstructCompositionT(
+    uint32_t version, llvm::ArrayRef<DirectiveWithClauses<C>> leafs)
+    : version(version), leafs(leafs) {
+  // Merge the list of constructs with clauses into a compound construct
+  // with a single list of clauses.
+  // The intended use of this function is in splitting compound constructs,
+  // while preserving composite constituent constructs:
+  // Step 1: split compound construct into leaf constructs.
+  // Step 2: identify composite sub-construct, and merge the constituent leafs.
+  //
+  // *** At the moment it assumes that the individual constructs and their
+  // *** clauses are a subset of those created by splitting a valid compound
+  // *** construct.
+  //
+  // 1. Deduplicate clauses
+  //    - exact duplicates: e.g. shared(x) shared(x) -> shared(x)
+  //    - special cases of clauses differing in modifier:
+  //      (a) reduction: inscan + (none|default) = inscan
+  //      (b) reduction: task + (none|default) = task
+  //      (c) combine repeated "if" clauses if possible
+  // 2. Merge DSA clauses: e.g. private(x) private(y) -> private(x, y).
+  // 3. Resolve potential DSA conflicts (typically due to implied clauses).
+
+  if (leafs.empty())
+    return;
+
+  merged.id = makeCompound(leafs);
+
+  // Populate clausePresence and clauseSets (dropping exact duplicates):
+  for (const auto &[index, leaf] : llvm::enumerate(leafs)) {
+    for (const auto &clause : leaf.clauses) {
+      // Update clausePresence.
+      auto &pset = clausePresence[clause.id];
+      if (pset.size() < leafs.size())
+        pset.resize(leafs.size());
+      pset.set(index);
+      // Update clauseSets.
+      ClauseSet &cset = clauseSets[clause.id];
+      if (!llvm::is_contained(cset, clause))
+        cset.push_back(clause);
+    }
+  }
+
+  mergeIf();
+  mergeReduction();
+  mergeDSA();
+
+  // Copy the remaining clauses (order follows clauseSets' hash-map iteration).
+  for (auto &[id, clauses] : clauseSets) {
+    // Skip clauses we've already dealt with.
+    switch (id) {
+    case llvm::omp::Clause::OMPC_if:
+    case llvm::omp::Clause::OMPC_reduction:
+    case llvm::omp::Clause::OMPC_shared:
+    case llvm::omp::Clause::OMPC_private:
+    case llvm::omp::Clause::OMPC_firstprivate:
+    case llvm::omp::Clause::OMPC_lastprivate:
+      continue;
+    default:
+      break;
+    }
+    llvm::append_range(merged.clauses, clauses);
+  }
+}
+
+template <typename C>
+llvm::omp::Directive ConstructCompositionT<C>::makeCompound(
+    llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts) {
+  // Collect the leaf ids in order; getCompoundConstruct maps them to one id.
+  llvm::SmallVector<llvm::omp::Directive> leafIds;
+  for (const DirectiveWithClauses<ClauseTy> &part : parts)
+    leafIds.push_back(part.id);
+  return llvm::omp::getCompoundConstruct(leafIds);
+}
+
+template <typename C>
+auto ConstructCompositionT<C>::checkPresence(llvm::omp::Clause clauseId)
+    -> Presence {
+  // Classify the clause by comparing, over the leafs that allow it, how many
+  // of them actually carry it.
+  auto where = clausePresence.find(clauseId);
+  if (where == clausePresence.end())
+    return Presence::None;
+  const llvm::BitVector &carriers = where->second;
+
+  size_t allowing = 0, carrying = 0;
+  for (size_t i = 0, e = leafs.size(); i != e; ++i) {
+    if (llvm::omp::isAllowedClauseForDirective(leafs[i].id, clauseId,
+                                               version)) {
+      ++allowing;
+      if (carriers.test(i))
+        ++carrying;
+    }
+  }
+
+  if (carrying == 0)
+    return Presence::None;
+  return carrying == allowing ? Presence::All : Presence::Some;
+}
+
+template <typename C> void ConstructCompositionT<C>::mergeIf() {
+  using IfTy = tomp::clause::IfT<TypeTy, IdTy, ExprTy>;
+  // Produce a single merged "if" clause. When the clause is present on
+  // every leaf that allows it, the merged clause carries no
+  // directive-name-modifier and so applies to the whole compound construct.
+  // Otherwise the modifier names the first leaf on which the clause
+  // appears.
+  // All "if" clauses are assumed to share the same expression.
+  Presence presence = checkPresence(llvm::omp::Clause::OMPC_if);
+  if (presence == Presence::None)
+    return;
+
+  // Take the expression from any one of the (assumed identical) clauses.
+  const ClauseTy &sample = clauseSets[llvm::omp::Clause::OMPC_if].front();
+  const auto &sampleIf = std::get<IfTy>(sample.u);
+  const auto &condition =
+      std::get<typename IfTy::IfExpression>(sampleIf.t);
+
+  std::optional<llvm::omp::Directive> nameModifier; // nullopt: no modifier
+  if (presence == Presence::Some) {
+    // Present on only some leafs: tie it to the first one that has it.
+    int firstLeaf = clausePresence[llvm::omp::Clause::OMPC_if].find_first();
+    assert(firstLeaf >= 0);
+    nameModifier = leafs[firstLeaf].id;
+  }
+
+  // Emit the merged clause.
+  merged.clauses.emplace_back(
+      makeClause(llvm::omp::Clause::OMPC_if,
+                 IfTy{{/*DirectiveNameModifier=*/nameModifier,
+                       /*IfExpression=*/condition}}));
+}
+
+template <typename C> void ConstructCompositionT<C>::mergeReduction() {
+  Presence presence = checkPresence(llvm::omp::Clause::OMPC_reduction);
+  if (presence == Presence::None)
+    return;
+
+  using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
+  using ModifierTy = typename ReductionTy::ReductionModifier;
+  using IdentifiersTy = typename ReductionTy::ReductionIdentifiers;
+  using ListTy = typename ReductionTy::List;
+  // There are exceptions on which constructs "reduction" may appear
+  // (specifically "parallel", and "teams"). Assume that if "reduction"
+  // is present, it can be applied to the compound construct.
+
+  // What's left is to see if there are any modifiers present. Again,
+  // assume that there are no conflicting modifiers.
+  // There can be, however, multiple reductions on different objects.
+  auto equal = [](const ClauseTy &red1, const ClauseTy &red2) {
+    // Bind by reference: copying would duplicate the identifier/object lists.
+    const auto &r1 = std::get<ReductionTy>(red1.u);
+    const auto &r2 = std::get<ReductionTy>(red2.u);
+    // Compare everything except modifiers.
+    if (std::get<IdentifiersTy>(r1.t) != std::get<IdentifiersTy>(r2.t))
+      return false;
+    if (std::get<ListTy>(r1.t) != std::get<ListTy>(r2.t))
+      return false;
+    return true;
+  };
+
+  auto getModifier = [](const ClauseTy &clause) {
+    const ReductionTy &red = std::get<ReductionTy>(clause.u);
+    return std::get<std::optional<ModifierTy>>(red.t);
+  };
+
+  const ClauseSet &reductions = clauseSets[llvm::omp::Clause::OMPC_reduction];
+  std::unordered_set<const ClauseTy *> visited;
+  while (reductions.size() != visited.size()) {
+    typename ClauseSet::const_iterator first;
+
+    // Find first non-visited reduction.
+    for (first = reductions.begin(); first != reductions.end(); ++first) {
+      if (visited.count(&*first))
+        continue;
+      visited.insert(&*first);
+      break;
+    }
+
+    std::optional<ModifierTy> modifier = getModifier(*first);
+
+    // Visit all other reductions that are "equal" (with respect to the
+    // definition above) to "first". Collect modifiers.
+    for (auto iter = std::next(first); iter != reductions.end(); ++iter) {
+      if (!equal(*first, *iter))
+        continue;
+      visited.insert(&*iter);
+      if (!modifier || *modifier == ModifierTy::Default)
+        modifier = getModifier(*iter);
+    }
+
+    const auto &firstRed = std::get<ReductionTy>(first->u);
+    merged.clauses.emplace_back(makeClause(
+        llvm::omp::Clause::OMPC_reduction,
+        ReductionTy{
+            {/*ReductionModifier=*/modifier,
+             /*ReductionIdentifiers=*/std::get<IdentifiersTy>(firstRed.t),
+             /*List=*/std::get<ListTy>(firstRed.t)}}));
+  }
+}
+
+template <typename C> void ConstructCompositionT<C>::mergeDSA() {
+  using ObjectTy = tomp::type::ObjectT<IdTy, ExprTy>;
+
+  // Resolve data-sharing attributes.
+  enum DSA : int {
+    None = 0,
+    Shared = 1 << 0,
+    Private = 1 << 1,
+    FirstPrivate = 1 << 2,
+    LastPrivate = 1 << 3,
+    LastPrivateConditional = 1 << 4,
+  };
+
+  // Use ordered containers to avoid non-deterministic output.
+  llvm::SmallVector<std::pair<ObjectTy, int>> objectDsa;
+
+  auto getDsa = [&](const ObjectTy &object) -> std::pair<ObjectTy, int> & {
+    auto found = llvm::find_if(objectDsa, [&](std::pair<ObjectTy, int> &p) {
+      return p.first.id() == object.id();
+    });
+    if (found != objectDsa.end())
+      return *found;
+    return objectDsa.emplace_back(object, DSA::None);
+  };
+
+  using SharedTy = tomp::clause::SharedT<TypeTy, IdTy, ExprTy>;
+  using PrivateTy = tomp::clause::PrivateT<TypeTy, IdTy, ExprTy>;
+  using FirstprivateTy = tomp::clause::FirstprivateT<TypeTy, IdTy, ExprTy>;
+  using LastprivateTy = tomp::clause::LastprivateT<TypeTy, IdTy, ExprTy>;
+
+  // Visit clauses that affect DSA.
+  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_shared]) {
+    for (auto &object : std::get<SharedTy>(clause.u).v)
+      getDsa(object).second |= DSA::Shared;
+  }
+
+  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_private]) {
+    for (auto &object : std::get<PrivateTy>(clause.u).v)
+      getDsa(object).second |= DSA::Private;
+  }
+
+  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_firstprivate]) {
+    for (auto &object : std::get<FirstprivateTy>(clause.u).v)
+      getDsa(object).second |= DSA::FirstPrivate;
+  }
+
+  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_lastprivate]) {
+    using ModifierTy = typename LastprivateTy::LastprivateModifier;
+    using ListTy = typename LastprivateTy::List;
+    const auto &lastp = std::get<LastprivateTy>(clause.u);
+    // The modifier applies to the whole clause, so decide the DSA bit once
+    // rather than re-reading it for every object in the list.
+    const auto &mod = std::get<std::optional<ModifierTy>>(lastp.t);
+    const int lastpBit = (mod && *mod == ModifierTy::Conditional)
+                             ? DSA::LastPrivateConditional
+                             : DSA::LastPrivate;
+    for (auto &object : std::get<ListTy>(lastp.t))
+      getDsa(object).second |= lastpBit;
+  }
+
+  // Check reductions as well, clear "shared" if set.
+  for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_reduction]) {
+    using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
+    using ListTy = typename ReductionTy::List;
+    for (auto &object : std::get<ListTy>(std::get<ReductionTy>(clause.u).t))
+      getDsa(object).second &= ~DSA::Shared;
+  }
+
+  tomp::ListT<ObjectTy> privateObj, sharedObj, firstpObj, lastpObj, lastpcObj;
+  for (auto &[object, dsa] : objectDsa) {
+    if (dsa &
+        (DSA::FirstPrivate | DSA::LastPrivate | DSA::LastPrivateConditional)) {
+      if (dsa & DSA::FirstPrivate)
+        firstpObj.push_back(object); // no else
+      if (dsa & DSA::LastPrivateConditional)
+        lastpcObj.push_back(object);
+      else if (dsa & DSA::LastPrivate)
+        lastpObj.push_back(object);
+    } else if (dsa & DSA::Private) {
+      privateObj.push_back(object);
+    } else if (dsa & DSA::Shared) {
+      sharedObj.push_back(object);
+    }
+  }
+
+  // Materialize each clause.
+  if (!privateObj.empty()) {
+    merged.clauses.emplace_back(
+        makeClause(llvm::omp::Clause::OMPC_private,
+                   PrivateTy{/*List=*/std::move(privateObj)}));
+  }
+  if (!sharedObj.empty()) {
+    merged.clauses.emplace_back(
+        makeClause(llvm::omp::Clause::OMPC_shared,
+                   SharedTy{/*List=*/std::move(sharedObj)}));
+  }
+  if (!firstpObj.empty()) {
+    merged.clauses.emplace_back(
+        makeClause(llvm::omp::Clause::OMPC_firstprivate,
+                   FirstprivateTy{/*List=*/std::move(firstpObj)}));
+  }
+  if (!lastpObj.empty()) {
+    merged.clauses.emplace_back(
+        makeClause(llvm::omp::Clause::OMPC_lastprivate,
+                   LastprivateTy{{/*LastprivateModifier=*/std::nullopt,
+                                  /*List=*/std::move(lastpObj)}}));
+  }
+  if (!lastpcObj.empty()) {
+    auto conditional = LastprivateTy::LastprivateModifier::Conditional;
+    merged.clauses.emplace_back(
+        makeClause(llvm::omp::Clause::OMPC_lastprivate,
+                   LastprivateTy{{/*LastprivateModifier=*/conditional,
+                                  /*List=*/std::move(lastpcObj)}}));
+  }
+}
+} // namespace tomp
+
+#endif // LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
diff --git a/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h b/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h
new file mode 100644
index 0000000000000..37c88f0fa07b4
--- /dev/null
+++ b/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h
@@ -0,0 +1,1161 @@
+//===- ConstructDecompositionT.h -- Decomposing compound constructs -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Given a compound construct with a set of clauses, generate the list of
+// constituent leaf constructs, each with a list of clauses that apply to it.
+//
+// Note: Clauses that are not originally present, but that are implied by the
+// OpenMP spec are materialized, and are present in the output.
+//
+// Note: Composite constructs will also be broken up into leaf constructs.
+// If composite constructs require processing as a whole, the lists of clauses
+// for each leaf constituent should be merged.
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_FRONTEND_OPENMP_CONSTRUCTDECOMPOSITIONT_H
+#define LLVM_FRONTEND_OPENMP_CONSTRUCTDECOMPOSITIONT_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Frontend/OpenMP/ClauseT.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+
+#include <iterator>
+#include <list>
+#include <optional>
+#include <tuple>
+#include <type_traits>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <variant>
+
+static inline llvm::ArrayRef<llvm::omp::Directive> getWorksharing() {
+  static constexpr llvm::omp::Directive worksharing[] = {
+      llvm::omp::Directive::OMPD_do,     llvm::omp::Directive::OMPD_for,
+      llvm::omp::Directive::OMPD_scope,  llvm::omp::Directive::OMPD_sections,
+      llvm::omp::Directive::OMPD_single, llvm::omp::Directive::OMPD_workshare,
+  };
+  return worksharing; // Read-only view of static storage; never dangles.
+}
+
+static inline llvm::ArrayRef<llvm::omp::Directive> getWorksharingLoop() {
+  static constexpr llvm::omp::Directive worksharingLoop[] = {
+      llvm::omp::Directive::OMPD_do,
+      llvm::omp::Directive::OMPD_for,
+  };
+  return worksharingLoop; // Read-only view of static storage; never dangles.
+}
+
+namespace detail {
+// Iterator to the sole element satisfying `pred`; end() if none or several.
+template <typename Container, typename Predicate>
+auto find_unique(Container &&container, Predicate &&pred)
+    -> decltype(container.begin()) {
+  auto first = std::find_if(container.begin(), container.end(), pred);
+  if (first == container.end())
+    return first;
+  auto second = std::find_if(std::next(first), container.end(), pred);
+  if (second == container.end())
+    return first;
+  return container.end();
+}
+} // namespace detail
+
+namespace tomp {
+
+// ClauseType - Either instance of ClauseT, or a type derived from ClauseT.
+//
+// This is the clause representation in the code using this infrastructure.
+//
+// HelperType - A class that implements two member functions:
+//
+// // Return the base object of the given object, if any.
+// std::optional<Object> getBaseObject(const Object &object) const
+// // Return the iteration variable of the outermost loop associated
+// // with the construct being worked on, if any.
+// std::optional<Object> getLoopIterVar() const
+template <typename ClauseType, typename HelperType>
+struct ConstructDecompositionT {
+ using ClauseTy = ClauseType;
+
+ using TypeTy = typename ClauseTy::TypeTy;
+ using IdTy = typename ClauseTy::IdTy;
+ using ExprTy = typename ClauseTy::ExprTy;
+ using HelperTy = HelperType;
+ using ObjectTy = tomp::ObjectT<IdTy, ExprTy>;
+
+ using ClauseSet = std::unordered_set<const ClauseTy *>;
+
+ ConstructDecompositionT(uint32_t ver, HelperType &helper,
+ llvm::omp::Directive dir,
+ llvm::ArrayRef<ClauseTy> clauses)
+ : version(ver), construct(dir), helper(helper) {
+ for (const ClauseTy &clause : clauses)
+ nodes.push_back(&clause);
+
+ bool success = split();
+ if (!success)
+ return;
+
+ // Copy the individual leaf directives with their clauses to the
+ // output list. Copy by value, since we don't own the storage
+ // with the input clauses, and the internal representation uses
+ // clause addresses.
+ for (auto &leaf : leafs) {
+ output.push_back({leaf.id});
+ auto &out = output.back();
+ for (const ClauseTy *c : leaf.clauses)
+ out.clauses.push_back(*c);
+ }
+ }
+
+ tomp::ListT<DirectiveWithClauses<ClauseType>> output;
+
+private:
+ bool split();
+
+ struct LeafReprInternal {
+ llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown;
+ tomp::type::ListT<const ClauseTy *> clauses;
+ };
+
+ LeafReprInternal *findDirective(llvm::omp::Directive dirId) {
+ auto found = llvm::find_if(
+ leafs, [&](const LeafReprInternal &leaf) { return leaf.id == dirId; });
+ return found != leafs.end() ? &*found : nullptr;
+ }
+
+ ClauseSet *findClausesWith(const ObjectTy &object) {
+ if (auto found = syms.find(object.id()); found != syms.end())
+ return &found->second;
+ return nullptr;
+ }
+
+ template <typename S>
+ ClauseTy *makeClause(llvm::omp::Clause clauseId, S &&specific) {
+ implicit.push_back(ClauseTy{clauseId, std::move(specific)});
+ return &implicit.back();
+ }
+
+ void addClauseSymsToMap(const ObjectTy &object, const ClauseTy *);
+ void addClauseSymsToMap(const tomp::ObjectListT<IdTy, ExprTy> &objects,
+ const ClauseTy *);
+ void addClauseSymsToMap(const TypeTy &item, const ClauseTy *);
+ void addClauseSymsToMap(const ExprTy &item, const ClauseTy *);
+ void addClauseSymsToMap(const tomp::clause::MapT<TypeTy, IdTy, ExprTy> &item,
+ const ClauseTy *);
+
+ template <typename U>
+ void addClauseSymsToMap(const std::optional<U> &item, const ClauseTy *);
+ template <typename U>
+ void addClauseSymsToMap(const tomp::ListT<U> &item, const ClauseTy *);
+ template <typename... U, size_t... Is>
+ void addClauseSymsToMap(const std::tuple<U...> &item, const ClauseTy *,
+ std::index_sequence<Is...> = {});
+ template <typename U>
+ std::enable_if_t<std::is_enum_v<llvm::remove_cvref_t<U>>, void>
+ addClauseSymsToMap(U &&item, const ClauseTy *);
+
+ template <typename U>
+ std::enable_if_t<llvm::remove_cvref_t<U>::EmptyTrait::value, void>
+ addClauseSymsToMap(U &&item, const ClauseTy *);
+
+ template <typename U>
+ std::enable_if_t<llvm::remove_cvref_t<U>::IncompleteTrait::value, void>
+ addClauseSymsToMap(U &&item, const ClauseTy *);
+
+ template <typename U>
+ std::enable_if_t<llvm::remove_cvref_t<U>::WrapperTrait::value, void>
+ addClauseSymsToMap(U &&item, const ClauseTy *);
+
+ template <typename U>
+ std::enable_if_t<llvm::remove_cvref_t<U>::TupleTrait::value, void>
+ addClauseSymsToMap(U &&item, const ClauseTy *);
+
+ template <typename U>
+ std::enable_if_t<llvm::remove_cvref_t<U>::UnionTrait::value, void>
+ addClauseSymsToMap(U &&item, const ClauseTy *);
+
+ // Apply a clause to the only directive that allows it. If there are no
+  // directives that allow it, or if there is more than one, do not apply
+ // anything and return false, otherwise return true.
+ bool applyToUnique(const ClauseTy *node);
+
+ // Apply a clause to the first directive in given range that allows it.
+ // If such a directive does not exist, return false, otherwise return true.
+ template <typename Iterator>
+ bool applyToFirst(const ClauseTy *node, llvm::iterator_range<Iterator> range);
+
+ // Apply a clause to the innermost directive that allows it. If such a
+ // directive does not exist, return false, otherwise return true.
+ bool applyToInnermost(const ClauseTy *node);
+
+ // Apply a clause to the outermost directive that allows it. If such a
+ // directive does not exist, return false, otherwise return true.
+ bool applyToOutermost(const ClauseTy *node);
+
+ template <typename Predicate>
+ bool applyIf(const ClauseTy *node, Predicate shouldApply);
+
+ bool applyToAll(const ClauseTy *node);
+
+ template <typename Clause>
+ bool applyClause(Clause &&clause, const ClauseTy *node);
+
+ bool applyClause(const tomp::clause::CollapseT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool applyClause(const tomp::clause::PrivateT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool
+ applyClause(const tomp::clause::FirstprivateT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool
+ applyClause(const tomp::clause::LastprivateT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool applyClause(const tomp::clause::SharedT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool applyClause(const tomp::clause::DefaultT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool
+ applyClause(const tomp::clause::ThreadLimitT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool applyClause(const tomp::clause::OrderT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool applyClause(const tomp::clause::AllocateT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool applyClause(const tomp::clause::ReductionT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool applyClause(const tomp::clause::IfT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool applyClause(const tomp::clause::LinearT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+ bool applyClause(const tomp::clause::NowaitT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *);
+
+ uint32_t version;
+ llvm::omp::Directive construct;
+ HelperType &helper;
+ ListT<LeafReprInternal> leafs;
+ tomp::ListT<const ClauseTy *> nodes;
+ std::list<ClauseTy> implicit; // Container for materialized implicit clauses.
+ // Inserting must preserve element addresses.
+ std::unordered_map<IdTy, ClauseSet> syms;
+ std::unordered_set<IdTy> mapBases;
+};
+
+// Class template argument deduction guide: both template arguments are
+// deduced from the constructor arguments, the helper object and the
+// ArrayRef holding the input clauses.
+template <typename C, typename H>
+ConstructDecompositionT(uint32_t, H &, llvm::omp::Directive,
+                        llvm::ArrayRef<C>) -> ConstructDecompositionT<C, H>;
+
+// Record that clause `node` mentions the symbol of `object`.
+template <typename C, typename H>
+void ConstructDecompositionT<C, H>::addClauseSymsToMap(const ObjectTy &object,
+                                                       const ClauseTy *node) {
+  ClauseSet &clausesForSym = syms[object.id()];
+  clausesForSym.insert(node);
+}
+
+// Record that clause `node` mentions the symbol of every object in `objects`.
+template <typename C, typename H>
+void ConstructDecompositionT<C, H>::addClauseSymsToMap(
+    const tomp::ObjectListT<IdTy, ExprTy> &objects, const ClauseTy *node) {
+  llvm::for_each(objects, [&](const ObjectTy &object) {
+    syms[object.id()].insert(node);
+  });
+}
+
+// Types do not refer to any symbols, so there is nothing to record.
+template <typename C, typename H>
+void ConstructDecompositionT<C, H>::addClauseSymsToMap(const TypeTy &,
+                                                       const ClauseTy *) {}
+
+// Expressions are not inspected for symbols, so there is nothing to record.
+template <typename C, typename H>
+void ConstructDecompositionT<C, H>::addClauseSymsToMap(const ExprTy &,
+                                                       const ClauseTy *) {}
+
+// A map clause records each of its list items. It also remembers the base
+// object (as reported by the helper) of every mapped item in `mapBases`, so
+// that later rules can check whether an item is the base of a mapped item.
+template <typename C, typename H>
+void ConstructDecompositionT<C, H>::addClauseSymsToMap(
+    const tomp::clause::MapT<TypeTy, IdTy, ExprTy> &item,
+    const ClauseTy *node) {
+  const auto &mapped = std::get<tomp::ObjectListT<IdTy, ExprTy>>(item.t);
+  for (const ObjectTy &object : mapped) {
+    syms[object.id()].insert(node);
+    if (auto base = helper.getBaseObject(object))
+      mapBases.insert(base->id());
+  }
+}
+
+// An empty optional contributes nothing; otherwise recurse into its value.
+template <typename C, typename H>
+template <typename U>
+void ConstructDecompositionT<C, H>::addClauseSymsToMap(
+    const std::optional<U> &item, const ClauseTy *node) {
+  if (!item.has_value())
+    return;
+  addClauseSymsToMap(item.value(), node);
+}
+
+// Recurse into every element of a list.
+template <typename C, typename H>
+template <typename U>
+void ConstructDecompositionT<C, H>::addClauseSymsToMap(
+    const tomp::ListT<U> &item, const ClauseTy *node) {
+  for (const U &element : item)
+    addClauseSymsToMap(element, node);
+}
+
+// Recurse into every element of a plain std::tuple.
+//
+// The trailing index-sequence parameter is defaulted. When this overload is
+// reached through a two-argument call (e.g. from the optional/list/wrapper
+// overloads for a std::tuple element), `Is...` is deduced as an empty pack
+// and the fold below would silently visit nothing. Re-dispatch with the
+// complete index sequence in that case.
+template <typename C, typename H>
+template <typename... U, size_t... Is>
+void ConstructDecompositionT<C, H>::addClauseSymsToMap(
+    const std::tuple<U...> &item, const ClauseTy *node,
+    std::index_sequence<Is...>) {
+  if constexpr (sizeof...(Is) == 0 && sizeof...(U) != 0) {
+    addClauseSymsToMap(item, node, std::index_sequence_for<U...>{});
+  } else {
+    (void)node; // Silence strange warning from GCC.
+    (addClauseSymsToMap(std::get<Is>(item), node), ...);
+  }
+}
+
+// Enumerators never refer to symbols, so there is nothing to record.
+template <typename C, typename H>
+template <typename U>
+std::enable_if_t<std::is_enum_v<llvm::remove_cvref_t<U>>, void>
+ConstructDecompositionT<C, H>::addClauseSymsToMap(U &&, const ClauseTy *) {}
+
+// Classes tagged with EmptyTrait have no members, so there is nothing to
+// record.
+template <typename C, typename H>
+template <typename U>
+std::enable_if_t<llvm::remove_cvref_t<U>::EmptyTrait::value, void>
+ConstructDecompositionT<C, H>::addClauseSymsToMap(U &&, const ClauseTy *) {}
+
+// Classes tagged with IncompleteTrait are treated as empty, so there is
+// nothing to record.
+template <typename C, typename H>
+template <typename U>
+std::enable_if_t<llvm::remove_cvref_t<U>::IncompleteTrait::value, void>
+ConstructDecompositionT<C, H>::addClauseSymsToMap(U &&, const ClauseTy *) {}
+
+// Classes tagged with WrapperTrait hold a single member `v`; recurse into it.
+template <typename C, typename H>
+template <typename U>
+std::enable_if_t<llvm::remove_cvref_t<U>::WrapperTrait::value, void>
+ConstructDecompositionT<C, H>::addClauseSymsToMap(U &&item,
+                                                  const ClauseTy *node) {
+  auto &&wrapped = item.v;
+  addClauseSymsToMap(wrapped, node);
+}
+
+// Classes tagged with TupleTrait keep their members in the tuple `t`;
+// recurse into each member via the std::tuple overload.
+template <typename C, typename H>
+template <typename U>
+std::enable_if_t<llvm::remove_cvref_t<U>::TupleTrait::value, void>
+ConstructDecompositionT<C, H>::addClauseSymsToMap(U &&item,
+                                                  const ClauseTy *node) {
+  using MembersTy = llvm::remove_cvref_t<decltype(item.t)>;
+  using Indices = std::make_index_sequence<std::tuple_size_v<MembersTy>>;
+  addClauseSymsToMap(item.t, node, Indices{});
+}
+
+// Classes tagged with UnionTrait keep a variant in `u`; only the active
+// alternative is visited.
+template <typename C, typename H>
+template <typename U>
+std::enable_if_t<llvm::remove_cvref_t<U>::UnionTrait::value, void>
+ConstructDecompositionT<C, H>::addClauseSymsToMap(U &&item,
+                                                  const ClauseTy *node) {
+  auto recordAlternative = [this, node](auto &&alternative) {
+    addClauseSymsToMap(alternative, node);
+  };
+  std::visit(recordAlternative, item.u);
+}
+
+// Apply a clause to the only directive that allows it. If there are no
+// directives that allow it, or if there is more than one, do not apply
+// anything and return false, otherwise return true.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyToUnique(const ClauseTy *node) {
+  auto allowsClause = [this, node](const auto &dirInfo) {
+    return llvm::omp::isAllowedClauseForDirective(dirInfo.id, node->id,
+                                                  version);
+  };
+  auto unique = detail::find_unique(leafs, allowsClause);
+  if (unique == leafs.end())
+    return false;
+
+  unique->clauses.push_back(node);
+  return true;
+}
+
+// Apply a clause to the first directive in the given range that allows it.
+// If such a directive does not exist (including when the range is empty),
+// return false, otherwise return true.
+template <typename C, typename H>
+template <typename Iterator>
+bool ConstructDecompositionT<C, H>::applyToFirst(
+    const ClauseTy *node, llvm::iterator_range<Iterator> range) {
+  auto allowsClause = [&](const auto &leaf) {
+    return llvm::omp::isAllowedClauseForDirective(leaf.id, node->id, version);
+  };
+  auto found = llvm::find_if(range, allowsClause);
+  if (found == range.end())
+    return false;
+
+  found->clauses.push_back(node);
+  return true;
+}
+
+// Apply a clause to the innermost directive that allows it. If such a
+// directive does not exist, return false, otherwise return true.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyToInnermost(const ClauseTy *node) {
+  // `leafs` is ordered outermost-first, so scan it backwards.
+  auto innermostFirst = llvm::reverse(leafs);
+  return applyToFirst(node, innermostFirst);
+}
+
+// Apply a clause to the outermost directive that allows it. If such a
+// directive does not exist, return false, otherwise return true.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyToOutermost(const ClauseTy *node) {
+  // `leafs` is ordered outermost-first, so scan it front to back.
+  return applyToFirst(node, llvm::make_range(leafs.begin(), leafs.end()));
+}
+
+// Apply a clause to every leaf that both allows it and satisfies
+// `shouldApply`. Returns true if the clause was applied at least once.
+template <typename C, typename H>
+template <typename Predicate>
+bool ConstructDecompositionT<C, H>::applyIf(const ClauseTy *node,
+                                            Predicate shouldApply) {
+  bool anyApplied = false;
+  for (auto &leaf : leafs) {
+    // The predicate is only consulted for leaves that allow the clause.
+    if (llvm::omp::isAllowedClauseForDirective(leaf.id, node->id, version) &&
+        shouldApply(leaf)) {
+      leaf.clauses.push_back(node);
+      anyApplied = true;
+    }
+  }
+  return anyApplied;
+}
+
+// Apply a clause to every leaf that allows it. Returns true if the clause was
+// applied to at least one leaf.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyToAll(const ClauseTy *node) {
+  // The predicate ignores its argument. Take it by reference: taking the
+  // leaf by value would copy its clause list on every invocation.
+  return applyIf(node, [](const auto &) { return true; });
+}
+
+template <typename C, typename H>
+template <typename Clause>
+bool ConstructDecompositionT<C, H>::applyClause(Clause &&clause,
+                                                const ClauseTy *node) {
+  // Fallback for clauses without a dedicated overload. The clause must be
+  // allowed on exactly one leaf construct, which then receives it. If no
+  // leaf allows it, or several do, nothing is applied and false is returned.
+  // From "OpenMP Application Programming Interface", Version 5.2:
+  // S Some clauses are permitted only on a single leaf construct of the
+  // S combined or composite construct, in which case the effect is as if
+  // S the clause is applied to that specific construct. (p339, 31-33)
+  return applyToUnique(node);
+}
+
+// COLLAPSE
+// [5.2:93:20-21]
+// Directives: distribute, do, for, loop, simd, taskloop
+//
+// [5.2:339:35]
+// (35) The collapse clause is applied once to the combined or composite
+// construct.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::CollapseT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  // Only the innermost leaf is a candidate. If it does not allow "collapse",
+  // report failure instead of moving the clause to another leaf.
+  if (leafs.empty())
+    return false;
+
+  auto &innermost = leafs.back();
+  if (!llvm::omp::isAllowedClauseForDirective(innermost.id, node->id,
+                                              version))
+    return false;
+
+  innermost.clauses.push_back(node);
+  return true;
+}
+
+// PRIVATE
+// [5.2:111:5-7]
+// Directives: distribute, do, for, loop, parallel, scope, sections, simd,
+// single, target, task, taskloop, teams
+//
+// [5.2:340:1-2]
+// (1) The effect of the private clause is as if it is applied only to the
+// innermost leaf construct that permits it.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::PrivateT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  // Scan from the innermost leaf outwards and stop at the first one that
+  // permits the clause.
+  return applyToFirst(node, llvm::reverse(leafs));
+}
+
+// FIRSTPRIVATE
+// [5.2:112:5-7]
+// Directives: distribute, do, for, parallel, scope, sections, single, target,
+// task, taskloop, teams
+//
+// [5.2:340:3-20]
+// (3) The effect of the firstprivate clause is as if it is applied to one or
+// more leaf constructs as follows:
+// (5) To the distribute construct if it is among the constituent constructs;
+// (6) To the teams construct if it is among the constituent constructs and the
+// distribute construct is not;
+// (8) To a worksharing construct that accepts the clause if one is among the
+// constituent constructs;
+// (9) To the taskloop construct if it is among the constituent constructs;
+// (10) To the parallel construct if it is among the constituent constructs and
+// neither a taskloop construct nor a worksharing construct that accepts
+// the clause is among them;
+// (12) To the target construct if it is among the constituent constructs and
+// the same list item neither appears in a lastprivate clause nor is the
+// base variable or base pointer of a list item that appears in a map
+// clause.
+//
+// (15) If the parallel construct is among the constituent constructs and the
+// effect is not as if the firstprivate clause is applied to it by the above
+// rules, then the effect is as if the shared clause with the same list item is
+// applied to the parallel construct.
+// (17) If the teams construct is among the constituent constructs and the
+// effect is not as if the firstprivate clause is applied to it by the above
+// rules, then the effect is as if the shared clause with the same list item is
+// applied to the teams construct.
+//
+// Returns true if the clause, or a firstprivate clause derived from it, was
+// applied to at least one leaf.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::FirstprivateT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  bool applied = false;
+
+  // [5.2:340:3-6]
+  auto dirDistribute = findDirective(llvm::omp::OMPD_distribute);
+  auto dirTeams = findDirective(llvm::omp::OMPD_teams);
+  if (dirDistribute != nullptr) {
+    dirDistribute->clauses.push_back(node);
+    applied = true;
+    // [5.2:340:17]
+    if (dirTeams != nullptr) {
+      auto *shared = makeClause(
+          llvm::omp::Clause::OMPC_shared,
+          tomp::clause::SharedT<TypeTy, IdTy, ExprTy>{/*List=*/clause.v});
+      dirTeams->clauses.push_back(shared);
+    }
+  } else if (dirTeams != nullptr) {
+    dirTeams->clauses.push_back(node);
+    applied = true;
+  }
+
+  // [5.2:340:8]
+  // Only a worksharing construct that *accepts* the clause counts. If the
+  // worksharing list includes constructs that do not accept firstprivate
+  // (e.g. "workshare"), such a leaf must neither receive the clause nor
+  // stop "parallel" from receiving it under [5.2:340:10].
+  auto findWorksharing = [&]() {
+    auto worksharing = getWorksharing();
+    for (auto &leaf : leafs) {
+      if (!llvm::is_contained(worksharing, leaf.id))
+        continue;
+      if (llvm::omp::isAllowedClauseForDirective(leaf.id, node->id, version))
+        return &leaf;
+    }
+    return static_cast<typename decltype(leafs)::value_type *>(nullptr);
+  };
+
+  auto dirWorksharing = findWorksharing();
+  if (dirWorksharing != nullptr) {
+    dirWorksharing->clauses.push_back(node);
+    applied = true;
+  }
+
+  // [5.2:340:9]
+  auto dirTaskloop = findDirective(llvm::omp::OMPD_taskloop);
+  if (dirTaskloop != nullptr) {
+    dirTaskloop->clauses.push_back(node);
+    applied = true;
+  }
+
+  // [5.2:340:10]
+  auto dirParallel = findDirective(llvm::omp::OMPD_parallel);
+  if (dirParallel != nullptr) {
+    if (dirTaskloop == nullptr && dirWorksharing == nullptr) {
+      dirParallel->clauses.push_back(node);
+      applied = true;
+    } else {
+      // [5.2:340:15]
+      auto *shared = makeClause(
+          llvm::omp::Clause::OMPC_shared,
+          tomp::clause::SharedT<TypeTy, IdTy, ExprTy>{/*List=*/clause.v});
+      dirParallel->clauses.push_back(shared);
+    }
+  }
+
+  // [5.2:340:12]
+  // True if `object` appears in some lastprivate clause recorded in `syms`.
+  auto inLastprivate = [&](const ObjectTy &object) {
+    if (ClauseSet *set = findClausesWith(object)) {
+      return llvm::find_if(*set, [](const ClauseTy *c) {
+               return c->id == llvm::omp::Clause::OMPC_lastprivate;
+             }) != set->end();
+    }
+    return false;
+  };
+
+  auto dirTarget = findDirective(llvm::omp::OMPD_target);
+  if (dirTarget != nullptr) {
+    tomp::ObjectListT<IdTy, ExprTy> objects;
+    llvm::copy_if(
+        clause.v, std::back_inserter(objects), [&](const ObjectTy &object) {
+          return !inLastprivate(object) && !mapBases.count(object.id());
+        });
+    if (!objects.empty()) {
+      auto *firstp = makeClause(
+          llvm::omp::Clause::OMPC_firstprivate,
+          tomp::clause::FirstprivateT<TypeTy, IdTy, ExprTy>{/*List=*/objects});
+      dirTarget->clauses.push_back(firstp);
+      applied = true;
+    }
+  }
+
+  // "task" is not handled by any of the cases above.
+  if (auto dirTask = findDirective(llvm::omp::OMPD_task)) {
+    dirTask->clauses.push_back(node);
+    applied = true;
+  }
+
+  return applied;
+}
+
+// LASTPRIVATE
+// [5.2:115:7-8]
+// Directives: distribute, do, for, loop, sections, simd, taskloop
+//
+// [5.2:340:21-30]
+// (21) The effect of the lastprivate clause is as if it is applied to all leaf
+// constructs that permit the clause.
+// (22) If the parallel construct is among the constituent constructs and the
+// list item is not also specified in the firstprivate clause, then the effect
+// of the lastprivate clause is as if the shared clause with the same list item
+// is applied to the parallel construct.
+// (24) If the teams construct is among the constituent constructs and the list
+// item is not also specified in the firstprivate clause, then the effect of the
+// lastprivate clause is as if the shared clause with the same list item is
+// applied to the teams construct.
+// (27) If the target construct is among the constituent constructs and the list
+// item is not the base variable or base pointer of a list item that appears in
+// a map clause, the effect of the lastprivate clause is as if the same list
+// item appears in a map clause with a map-type of tofrom.
+//
+// Returns false if no leaf permits "lastprivate". Otherwise the clause is
+// applied to every permitting leaf, and derived "shared" (on parallel/teams)
+// and "map(tofrom)" (on target) clauses may be materialized as above.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+ const tomp::clause::LastprivateT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *node) {
+ bool applied = false;
+
+ // [5.2:340:21]
+ applied = applyToAll(node);
+ if (!applied)
+ return false;
+
+ // True if `object` appears in some firstprivate clause recorded in `syms`.
+ auto inFirstprivate = [&](const ObjectTy &object) {
+ if (ClauseSet *set = findClausesWith(object)) {
+ return llvm::find_if(*set, [](const ClauseTy *c) {
+ return c->id == llvm::omp::Clause::OMPC_firstprivate;
+ }) != set->end();
+ }
+ return false;
+ };
+
+ auto &objects = std::get<tomp::ObjectListT<IdTy, ExprTy>>(clause.t);
+
+ // Prepare list of objects that could end up in a "shared" clause.
+ // Items that are also firstprivate are excluded, per (22) and (24).
+ tomp::ObjectListT<IdTy, ExprTy> sharedObjects;
+ llvm::copy_if(
+ objects, std::back_inserter(sharedObjects),
+ [&](const ObjectTy &object) { return !inFirstprivate(object); });
+
+ if (!sharedObjects.empty()) {
+ // [5.2:340:22]
+ if (auto dirParallel = findDirective(llvm::omp::OMPD_parallel)) {
+ auto *shared = makeClause(
+ llvm::omp::Clause::OMPC_shared,
+ tomp::clause::SharedT<TypeTy, IdTy, ExprTy>{/*List=*/sharedObjects});
+ dirParallel->clauses.push_back(shared);
+ applied = true;
+ }
+
+ // [5.2:340:24]
+ if (auto dirTeams = findDirective(llvm::omp::OMPD_teams)) {
+ auto *shared = makeClause(
+ llvm::omp::Clause::OMPC_shared,
+ tomp::clause::SharedT<TypeTy, IdTy, ExprTy>{/*List=*/sharedObjects});
+ dirTeams->clauses.push_back(shared);
+ applied = true;
+ }
+ }
+
+ // [5.2:340:27]
+ // Items whose id is recorded in `mapBases` (bases of mapped items) are
+ // not given an implicit map.
+ if (auto dirTarget = findDirective(llvm::omp::OMPD_target)) {
+ tomp::ObjectListT<IdTy, ExprTy> tofrom;
+ llvm::copy_if(
+ objects, std::back_inserter(tofrom),
+ [&](const ObjectTy &object) { return !mapBases.count(object.id()); });
+
+ if (!tofrom.empty()) {
+ using MapType =
+ typename tomp::clause::MapT<TypeTy, IdTy, ExprTy>::MapType;
+ auto *map =
+ makeClause(llvm::omp::Clause::OMPC_map,
+ tomp::clause::MapT<TypeTy, IdTy, ExprTy>{
+ {/*MapType=*/MapType::Tofrom,
+ /*MapTypeModifier=*/std::nullopt,
+ /*Mapper=*/std::nullopt, /*Iterator=*/std::nullopt,
+ /*LocatorList=*/std::move(tofrom)}});
+ dirTarget->clauses.push_back(map);
+ applied = true;
+ }
+ }
+
+ return applied;
+}
+
+// SHARED
+// [5.2:110:5-6]
+// Directives: parallel, task, taskloop, teams
+//
+// [5.2:340:31-32]
+// (31) The effect of the shared, default, thread_limit, or order clause is as
+// if it is applied to all leaf constructs that permit the clause.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::SharedT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  // [5.2:340:31] Every permitting leaf receives the clause; succeed if any
+  // leaf did.
+  const bool appliedToSome = applyToAll(node);
+  return appliedToSome;
+}
+
+// DEFAULT
+// [5.2:109:5-6]
+// Directives: parallel, task, taskloop, teams
+//
+// [5.2:340:31-32]
+// (31) The effect of the shared, default, thread_limit, or order clause is as
+// if it is applied to all leaf constructs that permit the clause.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::DefaultT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  // [5.2:340:31] Every permitting leaf receives the clause; succeed if any
+  // leaf did.
+  const bool appliedToSome = applyToAll(node);
+  return appliedToSome;
+}
+
+// THREAD_LIMIT
+// [5.2:277:14-15]
+// Directives: target, teams
+//
+// [5.2:340:31-32]
+// (31) The effect of the shared, default, thread_limit, or order clause is as
+// if it is applied to all leaf constructs that permit the clause.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::ThreadLimitT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  // [5.2:340:31] Every permitting leaf receives the clause; succeed if any
+  // leaf did.
+  const bool appliedToSome = applyToAll(node);
+  return appliedToSome;
+}
+
+// ORDER
+// [5.2:234:3-4]
+// Directives: distribute, do, for, loop, simd
+//
+// [5.2:340:31-32]
+// (31) The effect of the shared, default, thread_limit, or order clause is as
+// if it is applied to all leaf constructs that permit the clause.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::OrderT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  // [5.2:340:31] Every permitting leaf receives the clause; succeed if any
+  // leaf did.
+  const bool appliedToSome = applyToAll(node);
+  return appliedToSome;
+}
+
+// ALLOCATE
+// [5.2:178:7-9]
+// Directives: allocators, distribute, do, for, parallel, scope, sections,
+// single, target, task, taskgroup, taskloop, teams
+//
+// [5.2:340:33-35]
+// (33) The effect of the allocate clause is as if it is applied to all leaf
+// constructs that permit the clause and to which a data-sharing attribute
+// clause that may create a private copy of the same list item is applied.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::AllocateT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  // This one needs to be applied at the end, once we know which clauses are
+  // assigned to which leaf constructs.
+
+  // [5.2:340:33]
+  // Data-sharing attribute clauses that may create a private copy of a list
+  // item. Besides private/firstprivate/lastprivate this includes linear and
+  // the reduction clauses. The 5.2 spec also classifies the reduction
+  // clauses as data-sharing attribute clauses, and they privatize their
+  // list items.
+  auto canMakePrivateCopy = [](llvm::omp::Clause id) {
+    switch (id) {
+    case llvm::omp::Clause::OMPC_firstprivate:
+    case llvm::omp::Clause::OMPC_in_reduction:
+    case llvm::omp::Clause::OMPC_lastprivate:
+    case llvm::omp::Clause::OMPC_linear:
+    case llvm::omp::Clause::OMPC_private:
+    case llvm::omp::Clause::OMPC_reduction:
+    case llvm::omp::Clause::OMPC_task_reduction:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  // NOTE(review): the spec asks for a privatizing clause on the *same* list
+  // item; this only checks that the leaf has some privatizing clause.
+  return applyIf(node, [&](const auto &leaf) {
+    return llvm::any_of(leaf.clauses, [&](const ClauseTy *n) {
+      return canMakePrivateCopy(n->id);
+    });
+  });
+}
+
+// REDUCTION
+// [5.2:134:17-18]
+// Directives: do, for, loop, parallel, scope, sections, simd, taskloop, teams
+//
+// [5.2:340:36-37], [5.2:341:1-13]
+// (36) The effect of the reduction clause is as if it is applied to all leaf
+// constructs that permit the clause, except for the following constructs:
+// (1) The parallel construct, when combined with the sections,
+// worksharing-loop, loop, or taskloop construct; and
+// (3) The teams construct, when combined with the loop construct.
+// (4) For the parallel and teams constructs above, the effect of the reduction
+// clause instead is as if each list item or, for any list item that is an array
+// item, its corresponding base array or base pointer appears in a shared clause
+// for the construct.
+// (6) If the task reduction-modifier is specified, the effect is as if it only
+// modifies the behavior of the reduction clause on the innermost leaf construct
+// that accepts the modifier (see Section 5.5.8).
+// (8) If the inscan reduction-modifier is specified, the effect is as if it
+// modifies the behavior of the reduction clause on all constructs of the
+// combined construct to which the clause is applied and that accept the
+// modifier.
+// (10) If a list item in a reduction clause on a combined target construct does
+// not have the same base variable or base pointer as a list item in a map
+// clause on the construct, then the effect is as if the list item in the
+// reduction clause appears as a list item in a map clause with a map-type of
+// tofrom.
+//
+// Returns false if the clause could not be applied to any leaf. Otherwise it
+// returns true, after materializing any derived "shared" and "map(tofrom)"
+// clauses.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+ const tomp::clause::ReductionT<TypeTy, IdTy, ExprTy> &clause,
+ const ClauseTy *node) {
+ using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
+
+ // [5.2:340:36], [5.2:341:1], [5.2:341:3]
+ // Decide whether parallel/teams are excluded from receiving the clause.
+ bool applyToParallel = true, applyToTeams = true;
+
+ auto dirParallel = findDirective(llvm::omp::Directive::OMPD_parallel);
+ if (dirParallel) {
+ auto exclusions = llvm::concat<const llvm::omp::Directive>(
+ getWorksharingLoop(), tomp::ListT<llvm::omp::Directive>{
+ llvm::omp::Directive::OMPD_loop,
+ llvm::omp::Directive::OMPD_sections,
+ llvm::omp::Directive::OMPD_taskloop,
+ });
+ auto present = [&](llvm::omp::Directive id) {
+ return findDirective(id) != nullptr;
+ };
+
+ if (llvm::any_of(exclusions, present))
+ applyToParallel = false;
+ }
+
+ auto dirTeams = findDirective(llvm::omp::Directive::OMPD_teams);
+ if (dirTeams) {
+ // The only exclusion is OMPD_loop.
+ if (findDirective(llvm::omp::Directive::OMPD_loop))
+ applyToTeams = false;
+ }
+
+ using ReductionModifier = typename ReductionTy::ReductionModifier;
+ using ReductionIdentifiers = typename ReductionTy::ReductionIdentifiers;
+
+ auto &objects = std::get<tomp::ObjectListT<IdTy, ExprTy>>(clause.t);
+ auto &modifier = std::get<std::optional<ReductionModifier>>(clause.t);
+
+ // Apply the reduction clause first to all directives according to the spec.
+ // If the reduction was applied at least once, proceed with the data sharing
+ // side-effects.
+ bool applied = false;
+
+ // [5.2:341:6], [5.2:341:8]
+ // Whether leaf `dir` may receive the clause with modifier `mod`. For
+ // "task", only the first accepting leaf visited qualifies. The loop below
+ // walks innermost-first, so that leaf is the innermost accepting one.
+ auto isValidModifier = [](llvm::omp::Directive dir, ReductionModifier mod,
+ bool alreadyApplied) {
+ switch (mod) {
+ case ReductionModifier::Inscan:
+ // According to [5.2:135:11-13], "inscan" only applies to
+ // worksharing-loop, worksharing-loop-simd, or "simd" constructs.
+ return dir == llvm::omp::Directive::OMPD_simd ||
+ llvm::is_contained(getWorksharingLoop(), dir);
+ case ReductionModifier::Task:
+ if (alreadyApplied)
+ return false;
+ // According to [5.2:135:16-18], "task" only applies to "parallel" and
+ // worksharing constructs.
+ return dir == llvm::omp::Directive::OMPD_parallel ||
+ llvm::is_contained(getWorksharing(), dir);
+ case ReductionModifier::Default:
+ return true;
+ }
+ llvm_unreachable("Unexpected modifier");
+ };
+
+ // Copy of the clause with the modifier removed, for leaves that do not
+ // accept the modifier. It is materialized even if no leaf ends up using it.
+ auto *unmodified = makeClause(
+ llvm::omp::Clause::OMPC_reduction,
+ ReductionTy{
+ {/*ReductionModifier=*/std::nullopt,
+ /*ReductionIdentifiers=*/std::get<ReductionIdentifiers>(clause.t),
+ /*List=*/objects}});
+
+ ReductionModifier effective =
+ modifier.has_value() ? *modifier : ReductionModifier::Default;
+ bool effectiveApplied = false;
+ // Walk over the leaf constructs starting from the innermost, and apply
+ // the clause as required by the spec.
+ for (auto &leaf : llvm::reverse(leafs)) {
+ if (!llvm::omp::isAllowedClauseForDirective(leaf.id, node->id, version))
+ continue;
+ if (!applyToParallel && &leaf == dirParallel)
+ continue;
+ if (!applyToTeams && &leaf == dirTeams)
+ continue;
+ // Some form of the clause will be applied past this point.
+ if (isValidModifier(leaf.id, effective, effectiveApplied)) {
+ // Apply clause with modifier.
+ leaf.clauses.push_back(node);
+ effectiveApplied = true;
+ } else {
+ // Apply clause without modifier.
+ leaf.clauses.push_back(unmodified);
+ }
+ applied = true;
+ }
+
+ if (!applied)
+ return false;
+
+ // Each list item is replaced by its base object (if the helper reports
+ // one) before being used in a "shared" clause, per (4).
+ tomp::ObjectListT<IdTy, ExprTy> sharedObjects;
+ llvm::transform(objects, std::back_inserter(sharedObjects),
+ [&](const ObjectTy &object) {
+ auto maybeBase = helper.getBaseObject(object);
+ return maybeBase ? *maybeBase : object;
+ });
+
+ // [5.2:341:4]
+ if (!sharedObjects.empty()) {
+ if (dirParallel && !applyToParallel) {
+ auto *shared = makeClause(
+ llvm::omp::Clause::OMPC_shared,
+ tomp::clause::SharedT<TypeTy, IdTy, ExprTy>{/*List=*/sharedObjects});
+ dirParallel->clauses.push_back(shared);
+ }
+ if (dirTeams && !applyToTeams) {
+ auto *shared = makeClause(
+ llvm::omp::Clause::OMPC_shared,
+ tomp::clause::SharedT<TypeTy, IdTy, ExprTy>{/*List=*/sharedObjects});
+ dirTeams->clauses.push_back(shared);
+ }
+ }
+
+ // [5.2:341:10]
+ // `leafs.size() > 1` restricts this to combined (not stand-alone) target.
+ auto dirTarget = findDirective(llvm::omp::Directive::OMPD_target);
+ if (dirTarget && leafs.size() > 1) {
+ tomp::ObjectListT<IdTy, ExprTy> tofrom;
+ llvm::copy_if(objects, std::back_inserter(tofrom),
+ [&](const ObjectTy &object) {
+ if (auto maybeBase = helper.getBaseObject(object))
+ return !mapBases.count(maybeBase->id());
+ return !mapBases.count(object.id()); // NOTE(review): items without a base are checked by their own id; confirm this matches (10).
+ });
+ if (!tofrom.empty()) {
+ using MapType =
+ typename tomp::clause::MapT<TypeTy, IdTy, ExprTy>::MapType;
+ auto *map = makeClause(
+ llvm::omp::Clause::OMPC_map,
+ tomp::clause::MapT<TypeTy, IdTy, ExprTy>{
+ {/*MapType=*/MapType::Tofrom, /*MapTypeModifier=*/std::nullopt,
+ /*Mapper=*/std::nullopt, /*Iterator=*/std::nullopt,
+ /*LocatorList=*/std::move(tofrom)}});
+
+ dirTarget->clauses.push_back(map);
+ applied = true;
+ }
+ }
+
+ return applied;
+}
+
+// IF
+// [5.2:72:7-9]
+// Directives: cancel, parallel, simd, target, target data, target enter data,
+// target exit data, target update, task, taskloop
+//
+// [5.2:72:15-18]
+// (15) For combined or composite constructs, the if clause only applies to the
+// semantics of the construct named in the directive-name-modifier.
+// (16) For a combined or composite construct, if no directive-name-modifier is
+// specified then the if clause applies to all constituent constructs to which
+// an if clause can apply.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::IfT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  using IfTy = tomp::clause::IfT<TypeTy, IdTy, ExprTy>;
+  using DirectiveNameModifier = typename IfTy::DirectiveNameModifier;
+  using IfExpression = typename IfTy::IfExpression;
+
+  const auto &modifier =
+      std::get<std::optional<DirectiveNameModifier>>(clause.t);
+
+  // (16) No modifier: every leaf that accepts "if" receives the clause.
+  if (!modifier)
+    return applyToAll(node);
+
+  // (15) With a modifier, only the named leaf receives the clause, as a copy
+  // with the modifier stripped. The copy is materialized before the lookup.
+  llvm::omp::Directive named = *modifier;
+  auto *stripped =
+      makeClause(llvm::omp::Clause::OMPC_if,
+                 IfTy{{/*DirectiveNameModifier=*/std::nullopt,
+                       /*IfExpression=*/std::get<IfExpression>(clause.t)}});
+
+  auto *namedLeaf = findDirective(named);
+  if (namedLeaf == nullptr)
+    return false;
+  namedLeaf->clauses.push_back(stripped);
+  return true;
+}
+
+// LINEAR
+// [5.2:118:1-2]
+// Directives: declare simd, do, for, simd
+//
+// [5.2:341:15-22]
+// (15.1) The effect of the linear clause is as if it is applied to the
+// innermost leaf construct.
+// (15.2) Additionally, if the list item is not the iteration variable of a simd
+// or worksharing-loop SIMD construct, the effect on the outer leaf constructs
+// is as if the list item was specified in firstprivate and lastprivate clauses
+// on the combined or composite construct, with the rules specified above
+// applied.
+// (19) If a list item of the linear clause is the iteration variable of a simd
+// or worksharing-loop SIMD construct and it is not declared in the construct,
+// the effect on the outer leaf constructs is as if the list item was specified
+// in a lastprivate clause on the combined or composite construct with the rules
+// specified above applied.
+//
+// Returns false if no leaf accepts "linear". Otherwise it materializes the
+// derived firstprivate/lastprivate clauses, appends them to `nodes` so that
+// split() distributes them later, and returns true.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::LinearT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  // [5.2:341:15.1]
+  if (!applyToInnermost(node))
+    return false;
+
+  // [5.2:341:15.2], [5.2:341:19]
+  auto dirSimd = findDirective(llvm::omp::Directive::OMPD_simd);
+  std::optional<ObjectTy> iterVar = helper.getLoopIterVar();
+  const auto &objects = std::get<tomp::ObjectListT<IdTy, ExprTy>>(clause.t);
+
+  // Lists of objects that will be used to construct "firstprivate" and
+  // "lastprivate" clauses.
+  tomp::ObjectListT<IdTy, ExprTy> first, last;
+
+  for (const ObjectTy &object : objects) {
+    last.push_back(object);
+    if (!dirSimd || !iterVar || object.id() != iterVar->id())
+      first.push_back(object);
+  }
+
+  // The materialized clauses must also be registered in the symbol map. The
+  // firstprivate and lastprivate rules look items up there, e.g. firstprivate
+  // is not applied to "target" for an item that also appears in a
+  // lastprivate clause [5.2:340:12]. Without registration, such an item
+  // could end up both firstprivate and map(tofrom) on "target".
+  if (!first.empty()) {
+    auto *firstp = makeClause(
+        llvm::omp::Clause::OMPC_firstprivate,
+        tomp::clause::FirstprivateT<TypeTy, IdTy, ExprTy>{/*List=*/first});
+    addClauseSymsToMap(*firstp, firstp);
+    nodes.push_back(firstp); // Appending to the main clause list.
+  }
+  if (!last.empty()) {
+    auto *lastp =
+        makeClause(llvm::omp::Clause::OMPC_lastprivate,
+                   tomp::clause::LastprivateT<TypeTy, IdTy, ExprTy>{
+                       {/*LastprivateModifier=*/std::nullopt, /*List=*/last}});
+    addClauseSymsToMap(*lastp, lastp);
+    nodes.push_back(lastp); // Appending to the main clause list.
+  }
+  return true;
+}
+
+// NOWAIT
+// [5.2:308:11-13]
+// Directives: dispatch, do, for, interop, scope, sections, single, target,
+// target enter data, target exit data, target update, taskwait, workshare
+//
+// [5.2:341:23]
+// (23) The effect of the nowait clause is as if it is applied to the outermost
+// leaf construct that permits it.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::applyClause(
+    const tomp::clause::NowaitT<TypeTy, IdTy, ExprTy> &clause,
+    const ClauseTy *node) {
+  // Scan from the outermost leaf inwards and stop at the first one that
+  // permits the clause.
+  return applyToFirst(node, llvm::iterator_range(leafs));
+}
+
+// Distribute the clauses in `nodes` over the leaf constructs of `construct`.
+// Returns false as soon as some clause cannot be applied; no further clauses
+// are processed after such a failure.
+template <typename C, typename H>
+bool ConstructDecompositionT<C, H>::split() {
+  for (llvm::omp::Directive leaf :
+       llvm::omp::getLeafConstructsOrSelf(construct))
+    leafs.push_back(LeafReprInternal{leaf, /*clauses=*/{}});
+
+  for (const ClauseTy *node : nodes)
+    addClauseSymsToMap(*node, node);
+
+  // LINEAR goes first because it can materialize extra "firstprivate" and
+  // "lastprivate" clauses, which are appended to `nodes` and then
+  // distributed like any other clause. The linear clauses are gathered up
+  // front because applying them grows `nodes`.
+  llvm::SmallVector<const ClauseTy *> linears;
+  llvm::copy_if(nodes, std::back_inserter(linears), [](const ClauseTy *node) {
+    return node->id == llvm::omp::Clause::OMPC_linear;
+  });
+  for (const ClauseTy *node : linears) {
+    const auto &linear =
+        std::get<tomp::clause::LinearT<TypeTy, IdTy, ExprTy>>(node->u);
+    if (!applyClause(linear, node))
+      return false;
+  }
+
+  // Route a clause to the applyClause overload for its concrete type.
+  auto dispatch = [&](const ClauseTy *node) {
+    return std::visit([&](auto &&s) { return applyClause(s, node); },
+                      node->u);
+  };
+
+  // Everything except LINEAR (already handled) and ALLOCATE. ALLOCATE must
+  // come last because it looks at which leaves received data-privatizing
+  // clauses.
+  for (const ClauseTy *node : nodes) {
+    bool deferredOrDone = node->id == llvm::omp::Clause::OMPC_allocate ||
+                          node->id == llvm::omp::Clause::OMPC_linear;
+    if (deferredOrDone)
+      continue;
+    if (!dispatch(node))
+      return false;
+  }
+
+  for (const ClauseTy *node : nodes) {
+    if (node->id == llvm::omp::Clause::OMPC_allocate && !dispatch(node))
+      return false;
+  }
+
+  return true;
+}
+
+} // namespace tomp
+
+#endif // LLVM_FRONTEND_OPENMP_CONSTRUCTDECOMPOSITIONT_H
diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt
index 3f290b63ba647..85e113816e3bc 100644
--- a/llvm/unittests/Frontend/CMakeLists.txt
+++ b/llvm/unittests/Frontend/CMakeLists.txt
@@ -15,6 +15,7 @@ add_llvm_unittest(LLVMFrontendTests
OpenMPIRBuilderTest.cpp
OpenMPParsingTest.cpp
OpenMPCompositionTest.cpp
+ OpenMPDecompositionTest.cpp
DEPENDS
acc_gen
diff --git a/llvm/unittests/Frontend/OpenMPDecompositionTest.cpp b/llvm/unittests/Frontend/OpenMPDecompositionTest.cpp
new file mode 100644
index 0000000000000..df48e9cc0ff4a
--- /dev/null
+++ b/llvm/unittests/Frontend/OpenMPDecompositionTest.cpp
@@ -0,0 +1,999 @@
+//===- llvm/unittests/Frontend/OpenMPDecompositionTest.cpp ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Frontend/OpenMP/ClauseT.h"
+#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+#include "gtest/gtest.h"
+
+#include <iterator>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+
+ // The actual tests start at the "--- Tests" comment below.
+
+// Create simple instantiations of all clauses to allow manual construction
+// of clauses, and implement emitting of a directive with clauses to a string.
+//
+// The tests then follow the pattern
+// 1. Create a list of clauses.
+// 2. Pass them, together with a construct, to the decomposition class.
+// 3. Extract individual resulting leaf constructs with clauses applied
+// to them.
+// 4. Convert them to strings and compare with expected outputs.
+
+ namespace omp {
+ // Identifiers are plain strings so they can be printed as-is.
+ using IdTy = std::string;
+ // Empty placeholders for the type and expression representations; no test
+ // needs any content from them.
+ struct TypeTy {};
+ struct ExprTy {};
+ } // namespace omp
+
+ namespace tomp::type {
+ // Minimal object representation for the tests: an object is identified by
+ // its name only.
+ template <> struct ObjectT<omp::IdTy, omp::ExprTy> {
+   const omp::IdTy &id() const { return name; }
+   // Always yields a (placeholder) expression. Returned as a plain value: a
+   // top-level const on a by-value return type only prevents moving from it.
+   std::optional<omp::ExprTy> ref() const { return omp::ExprTy{}; }
+
+   omp::IdTy name;
+ };
+ } // namespace tomp::type
+
+ namespace omp {
+ template <typename ElemTy> using List = tomp::type::ListT<ElemTy>;
+
+ using Object = tomp::ObjectT<IdTy, ExprTy>;
+
+ // Concrete clause and helper types, instantiated with the placeholder
+ // TypeTy/IdTy/ExprTy declared above.
+ namespace clause {
+ using DefinedOperator = tomp::type::DefinedOperatorT<IdTy, ExprTy>;
+ using ProcedureDesignator = tomp::type::ProcedureDesignatorT<IdTy, ExprTy>;
+ using ReductionOperator = tomp::type::ReductionIdentifierT<IdTy, ExprTy>;
+
+ using AcqRel = tomp::clause::AcqRelT<TypeTy, IdTy, ExprTy>;
+ using Acquire = tomp::clause::AcquireT<TypeTy, IdTy, ExprTy>;
+ using AdjustArgs = tomp::clause::AdjustArgsT<TypeTy, IdTy, ExprTy>;
+ using Affinity = tomp::clause::AffinityT<TypeTy, IdTy, ExprTy>;
+ using Aligned = tomp::clause::AlignedT<TypeTy, IdTy, ExprTy>;
+ using Align = tomp::clause::AlignT<TypeTy, IdTy, ExprTy>;
+ using Allocate = tomp::clause::AllocateT<TypeTy, IdTy, ExprTy>;
+ using Allocator = tomp::clause::AllocatorT<TypeTy, IdTy, ExprTy>;
+ using AppendArgs = tomp::clause::AppendArgsT<TypeTy, IdTy, ExprTy>;
+ using AtomicDefaultMemOrder =
+     tomp::clause::AtomicDefaultMemOrderT<TypeTy, IdTy, ExprTy>;
+ using At = tomp::clause::AtT<TypeTy, IdTy, ExprTy>;
+ using Bind = tomp::clause::BindT<TypeTy, IdTy, ExprTy>;
+ using Capture = tomp::clause::CaptureT<TypeTy, IdTy, ExprTy>;
+ using Collapse = tomp::clause::CollapseT<TypeTy, IdTy, ExprTy>;
+ using Compare = tomp::clause::CompareT<TypeTy, IdTy, ExprTy>;
+ using Copyin = tomp::clause::CopyinT<TypeTy, IdTy, ExprTy>;
+ using Copyprivate = tomp::clause::CopyprivateT<TypeTy, IdTy, ExprTy>;
+ using Defaultmap = tomp::clause::DefaultmapT<TypeTy, IdTy, ExprTy>;
+ using Default = tomp::clause::DefaultT<TypeTy, IdTy, ExprTy>;
+ using Depend = tomp::clause::DependT<TypeTy, IdTy, ExprTy>;
+ using Destroy = tomp::clause::DestroyT<TypeTy, IdTy, ExprTy>;
+ using Detach = tomp::clause::DetachT<TypeTy, IdTy, ExprTy>;
+ using Device = tomp::clause::DeviceT<TypeTy, IdTy, ExprTy>;
+ using DeviceType = tomp::clause::DeviceTypeT<TypeTy, IdTy, ExprTy>;
+ using DistSchedule = tomp::clause::DistScheduleT<TypeTy, IdTy, ExprTy>;
+ using Doacross = tomp::clause::DoacrossT<TypeTy, IdTy, ExprTy>;
+ using DynamicAllocators =
+     tomp::clause::DynamicAllocatorsT<TypeTy, IdTy, ExprTy>;
+ using Enter = tomp::clause::EnterT<TypeTy, IdTy, ExprTy>;
+ using Exclusive = tomp::clause::ExclusiveT<TypeTy, IdTy, ExprTy>;
+ using Fail = tomp::clause::FailT<TypeTy, IdTy, ExprTy>;
+ using Filter = tomp::clause::FilterT<TypeTy, IdTy, ExprTy>;
+ using Final = tomp::clause::FinalT<TypeTy, IdTy, ExprTy>;
+ using Firstprivate = tomp::clause::FirstprivateT<TypeTy, IdTy, ExprTy>;
+ using From = tomp::clause::FromT<TypeTy, IdTy, ExprTy>;
+ using Full = tomp::clause::FullT<TypeTy, IdTy, ExprTy>;
+ using Grainsize = tomp::clause::GrainsizeT<TypeTy, IdTy, ExprTy>;
+ using HasDeviceAddr = tomp::clause::HasDeviceAddrT<TypeTy, IdTy, ExprTy>;
+ using Hint = tomp::clause::HintT<TypeTy, IdTy, ExprTy>;
+ using If = tomp::clause::IfT<TypeTy, IdTy, ExprTy>;
+ using Inbranch = tomp::clause::InbranchT<TypeTy, IdTy, ExprTy>;
+ using Inclusive = tomp::clause::InclusiveT<TypeTy, IdTy, ExprTy>;
+ using Indirect = tomp::clause::IndirectT<TypeTy, IdTy, ExprTy>;
+ using Init = tomp::clause::InitT<TypeTy, IdTy, ExprTy>;
+ using InReduction = tomp::clause::InReductionT<TypeTy, IdTy, ExprTy>;
+ using IsDevicePtr = tomp::clause::IsDevicePtrT<TypeTy, IdTy, ExprTy>;
+ using Lastprivate = tomp::clause::LastprivateT<TypeTy, IdTy, ExprTy>;
+ using Linear = tomp::clause::LinearT<TypeTy, IdTy, ExprTy>;
+ using Link = tomp::clause::LinkT<TypeTy, IdTy, ExprTy>;
+ using Map = tomp::clause::MapT<TypeTy, IdTy, ExprTy>;
+ using Match = tomp::clause::MatchT<TypeTy, IdTy, ExprTy>;
+ using Mergeable = tomp::clause::MergeableT<TypeTy, IdTy, ExprTy>;
+ using Message = tomp::clause::MessageT<TypeTy, IdTy, ExprTy>;
+ using Nocontext = tomp::clause::NocontextT<TypeTy, IdTy, ExprTy>;
+ using Nogroup = tomp::clause::NogroupT<TypeTy, IdTy, ExprTy>;
+ using Nontemporal = tomp::clause::NontemporalT<TypeTy, IdTy, ExprTy>;
+ using Notinbranch = tomp::clause::NotinbranchT<TypeTy, IdTy, ExprTy>;
+ using Novariants = tomp::clause::NovariantsT<TypeTy, IdTy, ExprTy>;
+ using Nowait = tomp::clause::NowaitT<TypeTy, IdTy, ExprTy>;
+ using NumTasks = tomp::clause::NumTasksT<TypeTy, IdTy, ExprTy>;
+ using NumTeams = tomp::clause::NumTeamsT<TypeTy, IdTy, ExprTy>;
+ using NumThreads = tomp::clause::NumThreadsT<TypeTy, IdTy, ExprTy>;
+ using OmpxAttribute = tomp::clause::OmpxAttributeT<TypeTy, IdTy, ExprTy>;
+ using OmpxBare = tomp::clause::OmpxBareT<TypeTy, IdTy, ExprTy>;
+ using OmpxDynCgroupMem = tomp::clause::OmpxDynCgroupMemT<TypeTy, IdTy, ExprTy>;
+ using Ordered = tomp::clause::OrderedT<TypeTy, IdTy, ExprTy>;
+ using Order = tomp::clause::OrderT<TypeTy, IdTy, ExprTy>;
+ using Partial = tomp::clause::PartialT<TypeTy, IdTy, ExprTy>;
+ using Priority = tomp::clause::PriorityT<TypeTy, IdTy, ExprTy>;
+ using Private = tomp::clause::PrivateT<TypeTy, IdTy, ExprTy>;
+ using ProcBind = tomp::clause::ProcBindT<TypeTy, IdTy, ExprTy>;
+ using Read = tomp::clause::ReadT<TypeTy, IdTy, ExprTy>;
+ using Reduction = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
+ using Relaxed = tomp::clause::RelaxedT<TypeTy, IdTy, ExprTy>;
+ using Release = tomp::clause::ReleaseT<TypeTy, IdTy, ExprTy>;
+ using ReverseOffload = tomp::clause::ReverseOffloadT<TypeTy, IdTy, ExprTy>;
+ using Safelen = tomp::clause::SafelenT<TypeTy, IdTy, ExprTy>;
+ using Schedule = tomp::clause::ScheduleT<TypeTy, IdTy, ExprTy>;
+ using SeqCst = tomp::clause::SeqCstT<TypeTy, IdTy, ExprTy>;
+ using Severity = tomp::clause::SeverityT<TypeTy, IdTy, ExprTy>;
+ using Shared = tomp::clause::SharedT<TypeTy, IdTy, ExprTy>;
+ using Simdlen = tomp::clause::SimdlenT<TypeTy, IdTy, ExprTy>;
+ using Simd = tomp::clause::SimdT<TypeTy, IdTy, ExprTy>;
+ using Sizes = tomp::clause::SizesT<TypeTy, IdTy, ExprTy>;
+ using TaskReduction = tomp::clause::TaskReductionT<TypeTy, IdTy, ExprTy>;
+ using ThreadLimit = tomp::clause::ThreadLimitT<TypeTy, IdTy, ExprTy>;
+ using Threads = tomp::clause::ThreadsT<TypeTy, IdTy, ExprTy>;
+ using To = tomp::clause::ToT<TypeTy, IdTy, ExprTy>;
+ using UnifiedAddress = tomp::clause::UnifiedAddressT<TypeTy, IdTy, ExprTy>;
+ using UnifiedSharedMemory =
+     tomp::clause::UnifiedSharedMemoryT<TypeTy, IdTy, ExprTy>;
+ using Uniform = tomp::clause::UniformT<TypeTy, IdTy, ExprTy>;
+ using Unknown = tomp::clause::UnknownT<TypeTy, IdTy, ExprTy>;
+ using Untied = tomp::clause::UntiedT<TypeTy, IdTy, ExprTy>;
+ using Update = tomp::clause::UpdateT<TypeTy, IdTy, ExprTy>;
+ using UseDeviceAddr = tomp::clause::UseDeviceAddrT<TypeTy, IdTy, ExprTy>;
+ using UseDevicePtr = tomp::clause::UseDevicePtrT<TypeTy, IdTy, ExprTy>;
+ using UsesAllocators = tomp::clause::UsesAllocatorsT<TypeTy, IdTy, ExprTy>;
+ using Use = tomp::clause::UseT<TypeTy, IdTy, ExprTy>;
+ using Weak = tomp::clause::WeakT<TypeTy, IdTy, ExprTy>;
+ using When = tomp::clause::WhenT<TypeTy, IdTy, ExprTy>;
+ using Write = tomp::clause::WriteT<TypeTy, IdTy, ExprTy>;
+ } // namespace clause
+
+ // Helper object handed to the decomposition. This test helper reports no
+ // base object for any object and no loop iteration variable. Neither query
+ // touches any state, so both are const.
+ struct Helper {
+   std::optional<Object> getBaseObject(const Object & /*object*/) const {
+     return std::nullopt;
+   }
+   std::optional<Object> getLoopIterVar() const { return std::nullopt; }
+ };
+
+ using Clause = tomp::ClauseT<TypeTy, IdTy, ExprTy>;
+ using ConstructDecomposition = tomp::ConstructDecompositionT<Clause, Helper>;
+ using DirectiveWithClauses = tomp::DirectiveWithClauses<Clause>;
+ } // namespace omp
+
+ namespace {
+ // Renders a clause as "<clause-name><contents>". The to_str overloads cover
+ // each kind of clause component: enums, lists, optionals, tuples, and the
+ // empty/incomplete/wrapper/tuple/union trait classes used by the clause
+ // types. Non-empty contents carry their own enclosing parentheses.
+ struct StringifyClause {
+   // Join the strings with ", " between consecutive elements.
+   static std::string join(const omp::List<std::string> &Strings) {
+     std::stringstream Stream;
+     for (const auto &[Index, String] : llvm::enumerate(Strings)) {
+       if (Index != 0)
+         Stream << ", ";
+       Stream << String;
+     }
+     return Stream.str();
+   }
+
+   static std::string to_str(llvm::omp::Directive D) {
+     return getOpenMPDirectiveName(D).str();
+   }
+   static std::string to_str(llvm::omp::Clause C) {
+     return getOpenMPClauseName(C).str();
+   }
+   static std::string to_str(const omp::TypeTy &Type) { return "type"; }
+   static std::string to_str(const omp::ExprTy &Expr) { return "expr"; }
+   static std::string to_str(const omp::Object &Obj) { return Obj.id(); }
+
+   // Enumerations print as their underlying integer value.
+   template <typename U>
+   static std::enable_if_t<std::is_enum_v<llvm::remove_cvref_t<U>>, std::string>
+   to_str(U &&Item) {
+     return std::to_string(llvm::to_underlying(Item));
+   }
+
+   template <typename U> static std::string to_str(const omp::List<U> &Items) {
+     omp::List<std::string> Names;
+     llvm::transform(Items, std::back_inserter(Names),
+                     [](auto &&S) { return to_str(S); });
+     return "(" + join(Names) + ")";
+   }
+
+   // An absent optional prints as the empty string.
+   template <typename U>
+   static std::string to_str(const std::optional<U> &Item) {
+     if (Item)
+       return to_str(*Item);
+     return "";
+   }
+
+   template <typename... Us, size_t... Is>
+   static std::string to_str(const std::tuple<Us...> &Tuple,
+                             std::index_sequence<Is...>) {
+     omp::List<std::string> Strings;
+     (Strings.push_back(to_str(std::get<Is>(Tuple))), ...);
+     return "(" + join(Strings) + ")";
+   }
+
+   template <typename U>
+   static std::enable_if_t<llvm::remove_cvref_t<U>::EmptyTrait::value,
+                           std::string>
+   to_str(U &&Item) {
+     return "";
+   }
+
+   template <typename U>
+   static std::enable_if_t<llvm::remove_cvref_t<U>::IncompleteTrait::value,
+                           std::string>
+   to_str(U &&Item) {
+     return "";
+   }
+
+   template <typename U>
+   static std::enable_if_t<llvm::remove_cvref_t<U>::WrapperTrait::value,
+                           std::string>
+   to_str(U &&Item) {
+     // For a wrapper, stringify the wrappee once, and only add parentheses
+     // if there aren't any already.
+     std::string Str = to_str(Item.v);
+     if (!Str.empty() && Str.front() == '(' && Str.back() == ')')
+       return Str;
+     return "(" + Str + ")";
+   }
+
+   template <typename U>
+   static std::enable_if_t<llvm::remove_cvref_t<U>::TupleTrait::value,
+                           std::string>
+   to_str(U &&Item) {
+     constexpr size_t TupleSize =
+         std::tuple_size_v<llvm::remove_cvref_t<decltype(Item.t)>>;
+     return to_str(Item.t, std::make_index_sequence<TupleSize>{});
+   }
+
+   template <typename U>
+   static std::enable_if_t<llvm::remove_cvref_t<U>::UnionTrait::value,
+                           std::string>
+   to_str(U &&Item) {
+     return std::visit([](auto &&S) { return to_str(S); }, Item.u);
+   }
+
+   StringifyClause(const omp::Clause &C)
+       // Rely on content stringification to emit enclosing parentheses.
+       : Str(to_str(C.id) + to_str(C)) {}
+
+   std::string Str;
+ };
+ } // namespace
+
+ // Render a leaf directive followed by each of its clauses, separated by
+ // single spaces, e.g. "sections private(x)". The function is file-local,
+ // so it gets internal linkage.
+ static std::string stringify(const omp::DirectiveWithClauses &DWC) {
+   std::stringstream Stream;
+
+   Stream << getOpenMPDirectiveName(DWC.id).str();
+   for (const omp::Clause &C : DWC.clauses)
+     Stream << ' ' << StringifyClause(C).Str;
+
+   return Stream.str();
+ }
+
+// --- Tests ----------------------------------------------------------
+
+namespace {
+using namespace llvm::omp;
+
+ // Fixture shared by all decomposition tests.
+ class OpenMPDecompositionTest : public testing::Test {
+ protected:
+   // Nothing to prepare or clean up between tests.
+   void SetUp() override {}
+   void TearDown() override {}
+
+   // Helper passed to every decomposition.
+   omp::Helper Helper{};
+   // Version passed to every decomposition. Presumably the exact value does
+   // not matter for these tests (hence the name); confirm if a test depends
+   // on version-specific rules.
+   uint32_t AnyVersion{999};
+ };
+
+// PRIVATE
+// [5.2:111:5-7]
+// Directives: distribute, do, for, loop, parallel, scope, sections, simd,
+// single, target, task, taskloop, teams
+//
+// [5.2:340:1-2]
+ // (1) The effect of the private clause is as if it is applied only to the
+// innermost leaf construct that permits it.
+ TEST_F(OpenMPDecompositionTest, Private1) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_private, omp::clause::Private{{x}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_parallel_sections,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   // (1): only the innermost leaf that permits "private" (sections) gets it.
+   ASSERT_EQ(stringify(Out[0]), "parallel");
+   ASSERT_EQ(stringify(Out[1]), "sections private(x)");
+ }
+
+ TEST_F(OpenMPDecompositionTest, Private2) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_private, omp::clause::Private{{x}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_parallel_masked,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   // (1): masked does not permit "private", so parallel is the innermost
+   // leaf that does.
+   ASSERT_EQ(stringify(Out[0]), "parallel private(x)");
+   ASSERT_EQ(stringify(Out[1]), "masked");
+ }
+
+// FIRSTPRIVATE
+// [5.2:112:5-7]
+// Directives: distribute, do, for, parallel, scope, sections, single, target,
+// task, taskloop, teams
+//
+// [5.2:340:3-20]
+// (3) The effect of the firstprivate clause is as if it is applied to one or
+// more leaf constructs as follows:
+// (5) To the distribute construct if it is among the constituent constructs;
+// (6) To the teams construct if it is among the constituent constructs and the
+// distribute construct is not;
+// (8) To a worksharing construct that accepts the clause if one is among the
+// constituent constructs;
+// (9) To the taskloop construct if it is among the constituent constructs;
+// (10) To the parallel construct if it is among the constituent constructs and
+// neither a taskloop construct nor a worksharing construct that accepts
+// the clause is among them;
+// (12) To the target construct if it is among the constituent constructs and
+// the same list item neither appears in a lastprivate clause nor is the
+// base variable or base pointer of a list item that appears in a map
+// clause.
+//
+// (15) If the parallel construct is among the constituent constructs and the
+// effect is not as if the firstprivate clause is applied to it by the above
+// rules, then the effect is as if the shared clause with the same list item is
+// applied to the parallel construct.
+// (17) If the teams construct is among the constituent constructs and the
+// effect is not as if the firstprivate clause is applied to it by the above
+// rules, then the effect is as if the shared clause with the same list item is
+// applied to the teams construct.
+ TEST_F(OpenMPDecompositionTest, Firstprivate1) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_firstprivate, omp::clause::Firstprivate{{x}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_parallel_sections,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   // (10), (15): a worksharing construct takes the clause, so parallel only
+   // shares x.
+   ASSERT_EQ(stringify(Out[0]), "parallel shared(x)");
+   // (8)
+   ASSERT_EQ(stringify(Out[1]), "sections firstprivate(x)");
+ }
+
+ TEST_F(OpenMPDecompositionTest, Firstprivate2) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_firstprivate, omp::clause::Firstprivate{{x}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper,
+                                   OMPD_target_teams_distribute, Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 3u);
+
+   ASSERT_EQ(stringify(Out[0]), "target firstprivate(x)");     // (12)
+   ASSERT_EQ(stringify(Out[1]), "teams shared(x)");            // (6), (17)
+   ASSERT_EQ(stringify(Out[2]), "distribute firstprivate(x)"); // (5)
+ }
+
+ TEST_F(OpenMPDecompositionTest, Firstprivate3) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_firstprivate, omp::clause::Firstprivate{{x}}},
+       {OMPC_lastprivate, omp::clause::Lastprivate{{std::nullopt, {x}}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper,
+                                   OMPD_target_teams_distribute, Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 3u);
+
+   // (12), (27): x is also lastprivate, so target maps it instead of
+   // privatizing it.
+   ASSERT_EQ(stringify(Out[0]), "target map(2, , , , (x))");
+   // (6), (17)
+   ASSERT_EQ(stringify(Out[1]), "teams shared(x)");
+   // (5), (21)
+   ASSERT_EQ(stringify(Out[2]),
+             "distribute firstprivate(x) lastprivate(, (x))");
+ }
+
+ TEST_F(OpenMPDecompositionTest, Firstprivate4) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_firstprivate, omp::clause::Firstprivate{{x}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_target_teams,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   ASSERT_EQ(stringify(Out[0]), "target firstprivate(x)"); // (12)
+   // (6): no distribute, so teams takes the clause itself.
+   ASSERT_EQ(stringify(Out[1]), "teams firstprivate(x)");
+ }
+
+ TEST_F(OpenMPDecompositionTest, Firstprivate5) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_firstprivate, omp::clause::Firstprivate{{x}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper,
+                                   OMPD_parallel_masked_taskloop, Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 3u);
+
+   // (10), (15): taskloop is present, so parallel only shares x.
+   ASSERT_EQ(stringify(Out[0]), "parallel shared(x)");
+   ASSERT_EQ(stringify(Out[1]), "masked");
+   ASSERT_EQ(stringify(Out[2]), "taskloop firstprivate(x)"); // (9)
+ }
+
+ TEST_F(OpenMPDecompositionTest, Firstprivate6) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_firstprivate, omp::clause::Firstprivate{{x}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_parallel_masked,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   // (10): neither taskloop nor a worksharing construct is present.
+   ASSERT_EQ(stringify(Out[0]), "parallel firstprivate(x)");
+   ASSERT_EQ(stringify(Out[1]), "masked");
+ }
+
+ TEST_F(OpenMPDecompositionTest, Firstprivate7) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_firstprivate, omp::clause::Firstprivate{{x}}},
+   };
+
+   // Composite constructs are decomposed just like combined ones.
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_teams_distribute,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   ASSERT_EQ(stringify(Out[0]), "teams shared(x)");            // (17)
+   ASSERT_EQ(stringify(Out[1]), "distribute firstprivate(x)"); // (5)
+ }
+
+// LASTPRIVATE
+// [5.2:115:7-8]
+// Directives: distribute, do, for, loop, sections, simd, taskloop
+//
+// [5.2:340:21-30]
+// (21) The effect of the lastprivate clause is as if it is applied to all leaf
+// constructs that permit the clause.
+// (22) If the parallel construct is among the constituent constructs and the
+// list item is not also specified in the firstprivate clause, then the effect
+// of the lastprivate clause is as if the shared clause with the same list item
+// is applied to the parallel construct.
+// (24) If the teams construct is among the constituent constructs and the list
+// item is not also specified in the firstprivate clause, then the effect of the
+// lastprivate clause is as if the shared clause with the same list item is
+// applied to the teams construct.
+// (27) If the target construct is among the constituent constructs and the list
+// item is not the base variable or base pointer of a list item that appears in
+// a map clause, the effect of the lastprivate clause is as if the same list
+// item appears in a map clause with a map-type of tofrom.
+ TEST_F(OpenMPDecompositionTest, Lastprivate1) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_lastprivate, omp::clause::Lastprivate{{std::nullopt, {x}}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_parallel_sections,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   ASSERT_EQ(stringify(Out[0]), "parallel shared(x)");          // (21), (22)
+   ASSERT_EQ(stringify(Out[1]), "sections lastprivate(, (x))"); // (21)
+ }
+
+ TEST_F(OpenMPDecompositionTest, Lastprivate2) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_lastprivate, omp::clause::Lastprivate{{std::nullopt, {x}}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_teams_distribute,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   ASSERT_EQ(stringify(Out[0]), "teams shared(x)");               // (21), (24)
+   ASSERT_EQ(stringify(Out[1]), "distribute lastprivate(, (x))"); // (21)
+ }
+
+ TEST_F(OpenMPDecompositionTest, Lastprivate3) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_lastprivate, omp::clause::Lastprivate{{std::nullopt, {x}}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_target_parallel_do,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 3u);
+
+   ASSERT_EQ(stringify(Out[0]), "target map(2, , , , (x))"); // (21), (27)
+   ASSERT_EQ(stringify(Out[1]), "parallel shared(x)");       // (22)
+   ASSERT_EQ(stringify(Out[2]), "do lastprivate(, (x))");    // (21)
+ }
+
+// SHARED
+// [5.2:110:5-6]
+// Directives: parallel, task, taskloop, teams
+//
+// [5.2:340:31-32]
+// (31) The effect of the shared, default, thread_limit, or order clause is as
+// if it is applied to all leaf constructs that permit the clause.
+ TEST_F(OpenMPDecompositionTest, Shared1) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_shared, omp::clause::Shared{{x}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper,
+                                   OMPD_parallel_masked_taskloop, Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 3u);
+
+   // (31): every leaf that permits "shared" receives it; masked does not.
+   ASSERT_EQ(stringify(Out[0]), "parallel shared(x)");
+   ASSERT_EQ(stringify(Out[1]), "masked");
+   ASSERT_EQ(stringify(Out[2]), "taskloop shared(x)");
+ }
+
+// DEFAULT
+// [5.2:109:5-6]
+// Directives: parallel, task, taskloop, teams
+//
+// [5.2:340:31-32]
+// (31) The effect of the shared, default, thread_limit, or order clause is as
+// if it is applied to all leaf constructs that permit the clause.
+ TEST_F(OpenMPDecompositionTest, Default1) {
+   // The default clause carries no list items, so no object is needed.
+   omp::List<omp::Clause> Clauses{
+       {OMPC_default,
+        omp::clause::Default{
+            omp::clause::Default::DataSharingAttribute::Firstprivate}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper,
+                                   OMPD_parallel_masked_taskloop, Clauses);
+   ASSERT_EQ(Dec.output.size(), 3u);
+
+   std::string Dir0 = stringify(Dec.output[0]);
+   std::string Dir1 = stringify(Dec.output[1]);
+   std::string Dir2 = stringify(Dec.output[2]);
+   // (31): every leaf that permits "default" receives it; masked does not.
+   ASSERT_EQ(Dir0, "parallel default(0)");
+   ASSERT_EQ(Dir1, "masked");
+   ASSERT_EQ(Dir2, "taskloop default(0)");
+ }
+
+// THREAD_LIMIT
+// [5.2:277:14-15]
+// Directives: target, teams
+//
+// [5.2:340:31-32]
+// (31) The effect of the shared, default, thread_limit, or order clause is as
+// if it is applied to all leaf constructs that permit the clause.
+ TEST_F(OpenMPDecompositionTest, ThreadLimit1) {
+   // thread_limit takes an expression, not a list item; no object is needed.
+   omp::List<omp::Clause> Clauses{
+       {OMPC_thread_limit, omp::clause::ThreadLimit{omp::ExprTy{}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper,
+                                   OMPD_target_teams_distribute, Clauses);
+   ASSERT_EQ(Dec.output.size(), 3u);
+
+   std::string Dir0 = stringify(Dec.output[0]);
+   std::string Dir1 = stringify(Dec.output[1]);
+   std::string Dir2 = stringify(Dec.output[2]);
+   // (31): applied to every leaf that permits it (target, teams).
+   ASSERT_EQ(Dir0, "target thread_limit(expr)");
+   ASSERT_EQ(Dir1, "teams thread_limit(expr)");
+   ASSERT_EQ(Dir2, "distribute");
+ }
+
+// ORDER
+// [5.2:234:3-4]
+// Directives: distribute, do, for, loop, simd
+//
+// [5.2:340:31-32]
+// (31) The effect of the shared, default, thread_limit, or order clause is as
+// if it is applied to all leaf constructs that permit the clause.
+ TEST_F(OpenMPDecompositionTest, Order1) {
+   // order carries no list items, so no object is needed.
+   omp::List<omp::Clause> Clauses{
+       {OMPC_order,
+        omp::clause::Order{{omp::clause::Order::OrderModifier::Unconstrained,
+                            omp::clause::Order::Ordering::Concurrent}}},
+   };
+
+   omp::ConstructDecomposition Dec(
+       AnyVersion, Helper, OMPD_target_teams_distribute_parallel_for_simd,
+       Clauses);
+   ASSERT_EQ(Dec.output.size(), 6u);
+
+   std::string Dir0 = stringify(Dec.output[0]);
+   std::string Dir1 = stringify(Dec.output[1]);
+   std::string Dir2 = stringify(Dec.output[2]);
+   std::string Dir3 = stringify(Dec.output[3]);
+   std::string Dir4 = stringify(Dec.output[4]);
+   std::string Dir5 = stringify(Dec.output[5]);
+   ASSERT_EQ(Dir0, "target"); // (31)
+   ASSERT_EQ(Dir1, "teams");  // (31)
+   // XXX OMP.td doesn't list "order" as allowed for "distribute"
+   ASSERT_EQ(Dir2, "distribute");       // (31)
+   ASSERT_EQ(Dir3, "parallel");         // (31)
+   ASSERT_EQ(Dir4, "for order(1, 0)");  // (31)
+   ASSERT_EQ(Dir5, "simd order(1, 0)"); // (31)
+ }
+
+// ALLOCATE
+// [5.2:178:7-9]
+// Directives: allocators, distribute, do, for, parallel, scope, sections,
+// single, target, task, taskgroup, taskloop, teams
+//
+// [5.2:340:33-35]
+// (33) The effect of the allocate clause is as if it is applied to all leaf
+// constructs that permit the clause and to which a data-sharing attribute
+// clause that may create a private copy of the same list item is applied.
+ TEST_F(OpenMPDecompositionTest, Allocate1) {
+   omp::Object x{"x"};
+   omp::List<omp::Clause> Clauses{
+       {OMPC_allocate,
+        omp::clause::Allocate{{std::nullopt, std::nullopt, std::nullopt, {x}}}},
+       {OMPC_private, omp::clause::Private{{x}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_parallel_sections,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   // (33): "allocate" follows the leaf that privatizes x (sections).
+   ASSERT_EQ(stringify(Out[0]), "parallel");
+   ASSERT_EQ(stringify(Out[1]), "sections private(x) allocate(, , , (x))");
+ }
+
+// REDUCTION
+// [5.2:134:17-18]
+// Directives: do, for, loop, parallel, scope, sections, simd, taskloop, teams
+//
+// [5.2:340-341:36-13]
+// (36) The effect of the reduction clause is as if it is applied to all leaf
+// constructs that permit the clause, except for the following constructs:
+// (1) The parallel construct, when combined with the sections,
+// worksharing-loop, loop, or taskloop construct; and
+// (3) The teams construct, when combined with the loop construct.
+// (4) For the parallel and teams constructs above, the effect of the reduction
+// clause instead is as if each list item or, for any list item that is an array
+// item, its corresponding base array or base pointer appears in a shared clause
+// for the construct.
+// (6) If the task reduction-modifier is specified, the effect is as if it only
+// modifies the behavior of the reduction clause on the innermost leaf construct
+// that accepts the modifier (see Section 5.5.8).
+// (8) If the inscan reduction-modifier is specified, the effect is as if it
+// modifies the behavior of the reduction clause on all constructs of the
+// combined construct to which the clause is applied and that accept the
+// modifier.
+// (10) If a list item in a reduction clause on a combined target construct does
+// not have the same base variable or base pointer as a list item in a map
+// clause on the construct, then the effect is as if the list item in the
+// reduction clause appears as a list item in a map clause with a map-type of
+// tofrom.
+ namespace red {
+ // Make it easier to build reduction operators from built-in intrinsic
+ // operators.
+ omp::clause::ReductionOperator
+ makeOp(omp::clause::DefinedOperator::IntrinsicOperator Op) {
+   omp::clause::DefinedOperator Defined{Op};
+   return omp::clause::ReductionOperator{std::move(Defined)};
+ }
+ } // namespace red
+
+ TEST_F(OpenMPDecompositionTest, Reduction1) {
+   omp::Object x{"x"};
+   auto Add = red::makeOp(omp::clause::DefinedOperator::IntrinsicOperator::Add);
+   omp::List<omp::Clause> Clauses{
+       {OMPC_reduction, omp::clause::Reduction{{std::nullopt, {Add}, {x}}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_parallel_sections,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   // (36), (1), (4): parallel combined with sections only shares x.
+   ASSERT_EQ(stringify(Out[0]), "parallel shared(x)");
+   ASSERT_EQ(stringify(Out[1]), "sections reduction(, (3), (x))"); // (36)
+ }
+
+ TEST_F(OpenMPDecompositionTest, Reduction2) {
+   omp::Object x{"x"};
+   auto Add = red::makeOp(omp::clause::DefinedOperator::IntrinsicOperator::Add);
+   omp::List<omp::Clause> Clauses{
+       {OMPC_reduction, omp::clause::Reduction{{std::nullopt, {Add}, {x}}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_parallel_masked,
+                                   Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   // (36): masked is not one of the constructs listed in (1), so parallel
+   // keeps the reduction.
+   ASSERT_EQ(stringify(Out[0]), "parallel reduction(, (3), (x))");
+   ASSERT_EQ(stringify(Out[1]), "masked");
+ }
+
+ TEST_F(OpenMPDecompositionTest, Reduction3) {
+   omp::Object x{"x"};
+   auto Add = red::makeOp(omp::clause::DefinedOperator::IntrinsicOperator::Add);
+   omp::List<omp::Clause> Clauses{
+       {OMPC_reduction, omp::clause::Reduction{{std::nullopt, {Add}, {x}}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper, OMPD_teams_loop, Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 2u);
+
+   // (36), (3), (4): teams combined with loop only shares x.
+   ASSERT_EQ(stringify(Out[0]), "teams shared(x)");
+   ASSERT_EQ(stringify(Out[1]), "loop reduction(, (3), (x))"); // (36)
+ }
+
+ TEST_F(OpenMPDecompositionTest, Reduction4) {
+   omp::Object x{"x"};
+   auto Add = red::makeOp(omp::clause::DefinedOperator::IntrinsicOperator::Add);
+   omp::List<omp::Clause> Clauses{
+       {OMPC_reduction, omp::clause::Reduction{{std::nullopt, {Add}, {x}}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper,
+                                   OMPD_teams_distribute_parallel_for, Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 4u);
+
+   // (36), (3): teams is not combined with loop here, so it keeps the clause.
+   ASSERT_EQ(stringify(Out[0]), "teams reduction(, (3), (x))");
+   ASSERT_EQ(stringify(Out[1]), "distribute"); // (36)
+   // (36), (1), (4): parallel combined with a worksharing-loop shares x.
+   ASSERT_EQ(stringify(Out[2]), "parallel shared(x)");
+   ASSERT_EQ(stringify(Out[3]), "for reduction(, (3), (x))"); // (36)
+ }
+
+ TEST_F(OpenMPDecompositionTest, Reduction5) {
+   omp::Object x{"x"};
+   auto Add = red::makeOp(omp::clause::DefinedOperator::IntrinsicOperator::Add);
+   auto TaskMod = omp::clause::Reduction::ReductionModifier::Task;
+   omp::List<omp::Clause> Clauses{
+       {OMPC_reduction, omp::clause::Reduction{{TaskMod, {Add}, {x}}}},
+   };
+
+   omp::ConstructDecomposition Dec(AnyVersion, Helper,
+                                   OMPD_teams_distribute_parallel_for, Clauses);
+   const auto &Out = Dec.output;
+   ASSERT_EQ(Out.size(), 4u);
+
+   // (6): the task modifier stays only on the innermost leaf accepting it.
+   ASSERT_EQ(stringify(Out[0]), "teams reduction(, (3), (x))"); // (36), (3)
+   ASSERT_EQ(stringify(Out[1]), "distribute");                  // (36)
+   ASSERT_EQ(stringify(Out[2]), "parallel shared(x)");          // (1), (4)
+   ASSERT_EQ(stringify(Out[3]), "for reduction(2, (3), (x))");  // (36), (6)
+ }
+
+ // reduction(inscan, +: x) on "teams distribute parallel for": per the
+ // expected output, the "for" leaf keeps the inscan modifier (printed as 1),
+ // while teams gets a plain reduction and parallel sees x as shared.
+ TEST_F(OpenMPDecompositionTest, Reduction6) {
+ omp::Object ObjX{"x"};
+ auto AddOp = red::makeOp(omp::clause::DefinedOperator::IntrinsicOperator::Add);
+ auto InscanModifier = omp::clause::Reduction::ReductionModifier::Inscan;
+
+ omp::List<omp::Clause> ClauseList{
+ {OMPC_reduction,
+ omp::clause::Reduction{{InscanModifier, {AddOp}, {ObjX}}}},
+ };
+
+ omp::ConstructDecomposition Decomp(AnyVersion, Helper,
+ OMPD_teams_distribute_parallel_for,
+ ClauseList);
+ ASSERT_EQ(Decomp.output.size(), 4u);
+
+ // (36), (3), (8)
+ ASSERT_EQ(stringify(Decomp.output[0]), "teams reduction(, (3), (x))");
+ // (36)
+ ASSERT_EQ(stringify(Decomp.output[1]), "distribute");
+ // (36), (1), (4)
+ ASSERT_EQ(stringify(Decomp.output[2]), "parallel shared(x)");
+ // (36), (8)
+ ASSERT_EQ(stringify(Decomp.output[3]), "for reduction(1, (3), (x))");
+}
+
+ // reduction(+: x) on "target parallel do": the target leaf keeps the
+ // reduction and additionally gets a map clause for x (map-type printed as 2,
+ // presumably tofrom per (10) -- confirm against the MapType enum); parallel
+ // sees x as shared and "do" carries the reduction.
+ TEST_F(OpenMPDecompositionTest, Reduction7) {
+ omp::Object ObjX{"x"};
+ auto AddOp = red::makeOp(omp::clause::DefinedOperator::IntrinsicOperator::Add);
+
+ omp::List<omp::Clause> ClauseList{
+ {OMPC_reduction, omp::clause::Reduction{{std::nullopt, {AddOp}, {ObjX}}}},
+ };
+
+ omp::ConstructDecomposition Decomp(AnyVersion, Helper,
+ OMPD_target_parallel_do, ClauseList);
+ ASSERT_EQ(Decomp.output.size(), 3u);
+
+ // XXX Currently OMP.td allows "reduction" on "target".
+ // (36), (10)
+ ASSERT_EQ(stringify(Decomp.output[0]),
+ "target reduction(, (3), (x)) map(2, , , , (x))");
+ // (36), (1), (4)
+ ASSERT_EQ(stringify(Decomp.output[1]), "parallel shared(x)");
+ // (36)
+ ASSERT_EQ(stringify(Decomp.output[2]), "do reduction(, (3), (x))");
+}
+
+// IF
+// [5.2:72:7-9]
+// Directives: cancel, parallel, simd, target, target data, target enter data,
+// target exit data, target update, task, taskloop
+//
+// [5.2:72:15-18]
+// (15) For combined or composite constructs, the if clause only applies to the
+// semantics of the construct named in the directive-name-modifier.
+// (16) For a combined or composite construct, if no directive-name-modifier is
+// specified then the if clause applies to all constituent constructs to which
+// an if clause can apply.
+ // if(parallel: expr) on "target parallel for simd": with an explicit
+ // directive-name-modifier, only the named leaf (parallel) receives the
+ // clause, and the expected output shows it without the modifier.
+ TEST_F(OpenMPDecompositionTest, If1) {
+ omp::List<omp::Clause> ClauseList{
+ {OMPC_if,
+ omp::clause::If{{llvm::omp::Directive::OMPD_parallel, omp::ExprTy{}}}},
+ };
+
+ omp::ConstructDecomposition Decomp(AnyVersion, Helper,
+ OMPD_target_parallel_for_simd,
+ ClauseList);
+ ASSERT_EQ(Decomp.output.size(), 4u);
+
+ // (15) for every leaf below.
+ ASSERT_EQ(stringify(Decomp.output[0]), "target");
+ ASSERT_EQ(stringify(Decomp.output[1]), "parallel if(, expr)");
+ ASSERT_EQ(stringify(Decomp.output[2]), "for");
+ ASSERT_EQ(stringify(Decomp.output[3]), "simd");
+}
+
+ // if(expr) without a directive-name-modifier on "target parallel for simd":
+ // every leaf that accepts "if" (target, parallel, simd) receives it, while
+ // "for" does not.
+ TEST_F(OpenMPDecompositionTest, If2) {
+ omp::List<omp::Clause> ClauseList{
+ {OMPC_if, omp::clause::If{{std::nullopt, omp::ExprTy{}}}},
+ };
+
+ omp::ConstructDecomposition Decomp(AnyVersion, Helper,
+ OMPD_target_parallel_for_simd,
+ ClauseList);
+ ASSERT_EQ(Decomp.output.size(), 4u);
+
+ // (16) for every leaf below.
+ ASSERT_EQ(stringify(Decomp.output[0]), "target if(, expr)");
+ ASSERT_EQ(stringify(Decomp.output[1]), "parallel if(, expr)");
+ ASSERT_EQ(stringify(Decomp.output[2]), "for");
+ ASSERT_EQ(stringify(Decomp.output[3]), "simd if(, expr)");
+}
+
+// LINEAR
+// [5.2:118:1-2]
+// Directives: declare simd, do, for, simd
+//
+// [5.2:341:15-22]
+// (15.1) The effect of the linear clause is as if it is applied to the
+// innermost leaf construct.
+// (15.2) Additionally, if the list item is not the iteration variable of a simd
+// or worksharing-loop SIMD construct, the effect on the outer leaf constructs
+// is as if the list item was specified in firstprivate and lastprivate clauses
+// on the combined or composite construct, with the rules specified above
+// applied.
+// (19) If a list item of the linear clause is the iteration variable of a simd
+// or worksharing-loop SIMD construct and it is not declared in the construct,
+// the effect on the outer leaf constructs is as if the list item was specified
+// in a lastprivate clause on the combined or composite construct with the rules
+// specified above applied.
+ // linear(x) on "for simd": the innermost simd leaf gets the linear clause,
+ // and the outer "for" leaf sees x as firstprivate and lastprivate.
+ TEST_F(OpenMPDecompositionTest, Linear1) {
+ omp::Object ObjX{"x"};
+
+ omp::List<omp::Clause> ClauseList{
+ {OMPC_linear,
+ omp::clause::Linear{{std::nullopt, std::nullopt, std::nullopt, {ObjX}}}},
+ };
+
+ omp::ConstructDecomposition Decomp(AnyVersion, Helper, OMPD_for_simd,
+ ClauseList);
+ ASSERT_EQ(Decomp.output.size(), 2u);
+
+ // (15.1), (15.2)
+ ASSERT_EQ(stringify(Decomp.output[0]),
+ "for firstprivate(x) lastprivate(, (x))");
+ // (15.1). NOTE(review): the lastprivate on simd presumably comes from the
+ // (15.2) lastprivate being applied to the whole combined construct -- confirm.
+ ASSERT_EQ(stringify(Decomp.output[1]),
+ "simd linear(, , , (x)) lastprivate(, (x))");
+}
+
+// NOWAIT
+// [5.2:308:11-13]
+// Directives: dispatch, do, for, interop, scope, sections, single, target,
+// target enter data, target exit data, target update, taskwait, workshare
+//
+// [5.2:341:23]
+// (23) The effect of the nowait clause is as if it is applied to the outermost
+// leaf construct that permits it.
+ // nowait on "target parallel for": only the outermost leaf that permits it
+ // (target) receives the clause.
+ TEST_F(OpenMPDecompositionTest, Nowait1) {
+ omp::List<omp::Clause> ClauseList{
+ {OMPC_nowait, omp::clause::Nowait{}},
+ };
+
+ omp::ConstructDecomposition Decomp(AnyVersion, Helper,
+ OMPD_target_parallel_for, ClauseList);
+ ASSERT_EQ(Decomp.output.size(), 3u);
+
+ // (23) for every leaf below.
+ ASSERT_EQ(stringify(Decomp.output[0]), "target nowait");
+ ASSERT_EQ(stringify(Decomp.output[1]), "parallel");
+ ASSERT_EQ(stringify(Decomp.output[2]), "for");
+}
+} // namespace
More information about the flang-commits
mailing list