[llvm-branch-commits] [flang] [mlir] [Flang][MLIR][OpenMP] Explicitly represent omp.target kernel types (PR #186166)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Mar 12 09:18:20 PDT 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-openmp
Author: Sergio Afonso (skatrak)
<details>
<summary>Changes</summary>
Currently, the kernel type (i.e. `generic`, `spmd`, `spmd-no-loop` or `bare`) of an `omp.target` operation is not an explicit attribute of the operation. Rather, it is inferred from the contents of its region and its clauses.
This approach has two problems: the check can be resource-intensive for large kernels, and the kernel type is prone to being misidentified when arbitrary operations from other dialects are present.
Since the AST already contains the information needed to identify the kernel type more reliably, this patch moves that responsibility to the Flang frontend. Other MLIR passes that create `omp.target` operations are updated as well.
One known limitation of this approach is that the MLIR op verifier for `omp.target` can't fully check that the contents of its region are compatible with the declared kernel type without reintroducing the same pattern-matching limitations that this patch removes. Also, the `TargetOp::getInnermostCapturedOmpOp()` function is kept for now but, ideally, a better solution should be implemented to remove its expensive and potentially flaky checks from MLIR.
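As a rough illustration of the frontend-side decision described above, the following standalone sketch classifies a target region from its directive names and a few clause flags. It is not the code in this patch: `ExecMode`, `TargetRegion`, `isSpmdPattern` and `classify` are hypothetical names, the region is modelled as a flat list of leaf constructs (so combined and nested directives are not distinguished, which the patch does in places), and clause and module flags are reduced to booleans. The SPMD patterns and the no-loop conditions are taken from the comments in the diff below.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for the kernel types named in the summary.
enum class ExecMode { Generic, Spmd, SpmdNoLoop, Bare };

// Hypothetical, flattened view of one target region (an assumption of this
// sketch; the patch walks PFT evaluations instead).
struct TargetRegion {
  std::vector<std::string> leaves; // e.g. {"target", "teams", "loop"}
  bool hasOmpxBare = false;        // ompx_bare clause present
  bool hasNumTeams = false;        // num_teams clause present
  bool hasReduction = false;       // reduction clause present
  bool assumeTeamsOversubscription = false;   // module-level flag
  bool assumeThreadsOversubscription = false; // module-level flag
};

// Patterns listed as SPMD in the patch comments.
static bool isSpmdPattern(const std::vector<std::string> &leaves) {
  static const std::vector<std::vector<std::string>> patterns = {
      {"target", "teams", "distribute", "parallel", "do"},
      {"target", "teams", "distribute", "parallel", "do", "simd"},
      {"target", "teams", "loop"},
      {"target", "parallel", "do"},
      {"target", "parallel", "do", "simd"},
      {"target", "parallel", "loop"},
  };
  for (const auto &pattern : patterns)
    if (pattern == leaves)
      return true;
  return false;
}

ExecMode classify(const TargetRegion &region) {
  // 'target teams' carrying ompx_bare is treated as a bare kernel.
  if (region.hasOmpxBare && region.leaves.size() >= 2 &&
      region.leaves[0] == "target" && region.leaves[1] == "teams")
    return ExecMode::Bare;

  if (!isSpmdPattern(region.leaves))
    return ExecMode::Generic;

  // No-loop promotion needs both oversubscription assumptions and must not
  // see num_teams or reduction clauses.
  bool flagsOk = region.assumeTeamsOversubscription &&
                 region.assumeThreadsOversubscription;
  if (flagsOk && !region.hasNumTeams && !region.hasReduction)
    return ExecMode::SpmdNoLoop;
  return ExecMode::Spmd;
}
```

In this simplified model, a `target teams distribute parallel do` region is classified as SPMD, becomes SPMD no-loop once both oversubscription flags are set, and falls back to SPMD if a `reduction` clause is present; anything that matches no listed pattern stays generic.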
---
Patch is 384.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/186166.diff
144 Files Affected:
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (-4)
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (-1)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+397-197)
- (modified) flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp (+5-3)
- (modified) flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp (+34-32)
- (modified) flang/test/Analysis/AliasAnalysis/alias-analysis-omp-target-1.fir (+1-1)
- (modified) flang/test/Analysis/AliasAnalysis/alias-analysis-omp-target-2.fir (+2-2)
- (modified) flang/test/Fir/OpenMP/bounds-generation-for-char-arrays.f90 (+4-4)
- (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+20-20)
- (modified) flang/test/Lower/OpenMP/DelayedPrivatization/target-private-allocatable.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/DelayedPrivatization/target-private-implicit-scalar-map-2.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/DelayedPrivatization/target-private-implicit-scalar-map.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/DelayedPrivatization/target-private-simple.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/DelayedPrivatization/target-teams-private-implicit-scalar-map.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/KernelLanguage/bare-clause.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/allocatable-map.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/array-bounds.f90 (+3-3)
- (modified) flang/test/Lower/OpenMP/common-block-map.f90 (+3-3)
- (modified) flang/test/Lower/OpenMP/declare-mapper.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/defaultmap.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/derived-type-allocatable-map.f90 (+4-4)
- (modified) flang/test/Lower/OpenMP/derived-type-map.f90 (+10-10)
- (modified) flang/test/Lower/OpenMP/distribute-parallel-do-simd.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/local-intrinsic-sized-array-map.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/location.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/map-character.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/map-component-ref.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/map-descriptor-deferral.f90 (+4-4)
- (modified) flang/test/Lower/OpenMP/map-mapper.f90 (+2-2)
- (modified) flang/test/Lower/OpenMP/optional-argument-map-2.f90 (+2-2)
- (modified) flang/test/Lower/OpenMP/optional-argument-map-3.f90 (+2-2)
- (modified) flang/test/Lower/OpenMP/target-map-complex.f90 (+2-2)
- (modified) flang/test/Lower/OpenMP/target-parallel-private.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/target-spmd.f90 (+20-20)
- (modified) flang/test/Lower/OpenMP/target-teams-private.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/target.f90 (+16-16)
- (modified) flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 (+1-1)
- (modified) flang/test/Lower/OpenMP/workdistribute.f90 (+1-1)
- (modified) flang/test/Transforms/DoConcurrent/host_eval.f90 (+2-2)
- (modified) flang/test/Transforms/DoConcurrent/local_device.mlir (+1-1)
- (modified) flang/test/Transforms/DoConcurrent/map_shape_info.f90 (+2-2)
- (modified) flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 (+1-1)
- (modified) flang/test/Transforms/DoConcurrent/non_reference_to_device.f90 (+1-1)
- (modified) flang/test/Transforms/DoConcurrent/reduce_device.mlir (+1-1)
- (modified) flang/test/Transforms/OpenMP/delete-unreachable-targets.mlir (+25-25)
- (modified) flang/test/Transforms/OpenMP/function-filtering-host-ops.mlir (+37-36)
- (modified) flang/test/Transforms/OpenMP/function-filtering.mlir (+7-7)
- (modified) flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir (+2-2)
- (modified) flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir (+2-2)
- (modified) flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir (+2-2)
- (modified) flang/test/Transforms/OpenMP/simd-only.mlir (+2-2)
- (modified) flang/test/Transforms/debug-omp-target-op-1.fir (+1-1)
- (modified) flang/test/Transforms/debug-omp-target-op-2.fir (+1-1)
- (modified) flang/test/Transforms/omp-function-filtering-todo.mlir (+1-1)
- (modified) flang/test/Transforms/omp-map-info-finalization-implicit-field.fir (+2-2)
- (modified) flang/test/Transforms/omp-map-info-finalization.fir (+18-18)
- (modified) flang/test/Transforms/omp-maps-for-privatized-symbols.fir (+2-2)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h (+6)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (-26)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td (+5-1)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+26-24)
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+111-135)
- (modified) mlir/lib/Dialect/OpenMP/Utils/Utils.cpp (+1-2)
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+9-12)
- (modified) mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir (+4-4)
- (modified) mlir/test/Dialect/OpenMP/canonicalize.mlir (+1-1)
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+110-23)
- (modified) mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare-by-value.mlir (+2-2)
- (modified) mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare.mlir (+2-2)
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+84-38)
- (modified) mlir/test/Dialect/OpenMP/stack-to-shared.mlir (+3-3)
- (modified) mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omp-target-call-with-repeated-parameter.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-array-sectioning-host.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-atomic-capture-control-options.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-atomic-update-control-options.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-device.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-host.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-constant-alloca-raise.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-constant-indexing-device-region.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-debug-147063.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-debug-empty.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-debug-loop-loc.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-debug-map-link-loc.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-debug-nowait.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-debug-var-1.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-debug-var-2.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-debug.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-debug2.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-depend-host-only.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-depend.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-device-shared-memory.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-device.mlir (+6-6)
- (modified) mlir/test/Target/LLVMIR/omptarget-fortran-common-block-host.mlir (+2-2)
- (modified) mlir/test/Target/LLVMIR/omptarget-if-nowait.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-if.mlir (+3-3)
- (modified) mlir/test/Target/LLVMIR/omptarget-llvm.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-mapper-combined-entry.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-memcpy-align-metadata.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-nested-ptr-record-type-mapping-host.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-nested-record-type-mapping-host.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-nowait-host-only.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-nowait-llvm.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-nowait.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-nullary-record-ptr-member-map.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-parallel-llvm-debug.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir (+3-3)
- (modified) mlir/test/Target/LLVMIR/omptarget-private-llvm.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-record-type-mapping-host.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-region-host-only.mlir (+2-2)
- (modified) mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir (+2-2)
- (modified) mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-target-cpu-features.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction-array-descriptor.mlir (+2-2)
- (modified) mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/openmp-data-target-device.mlir (+6-6)
- (modified) mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir (+2-2)
- (modified) mlir/test/Target/LLVMIR/openmp-nested-task-target-parallel.mlir (+8-10)
- (modified) mlir/test/Target/LLVMIR/openmp-private-allloca-hoisting.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/openmp-target-default-as.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir (+2-2)
- (modified) mlir/test/Target/LLVMIR/openmp-target-has-device-addr.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir (+2-2)
- (modified) mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/openmp-target-multiple-private.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/openmp-target-private-allocatable.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/openmp-target-private-shared-mem.mlir (+3-3)
- (modified) mlir/test/Target/LLVMIR/openmp-target-private.mlir (+5-5)
- (modified) mlir/test/Target/LLVMIR/openmp-target-simd-on_device.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/openmp-target-spmd.mlir (+2-2)
- (modified) mlir/test/Target/LLVMIR/openmp-target-wsloop-private.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/openmp-teams-clauses-trunc-ext.mlir (+12-12)
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (+2-2)
``````````diff
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index e62395676a696..44a955ae5c4dd 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -206,10 +206,6 @@ getIfClauseOperand(lower::AbstractConverter &converter,
// ClauseProcessor unique clauses
//===----------------------------------------------------------------------===//
-bool ClauseProcessor::processBare(mlir::omp::BareClauseOps &result) const {
- return markClauseOccurrence<omp::clause::OmpxBare>(result.bare);
-}
-
bool ClauseProcessor::processBind(mlir::omp::BindClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Bind>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index da920407b2164..95ac2a767e20d 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -57,7 +57,6 @@ class ClauseProcessor {
: converter(converter), semaCtx(semaCtx), clauses(clauses) {}
// 'Unique' clauses: They can appear at most once in the clause list.
- bool processBare(mlir::omp::BareClauseOps &result) const;
bool processBind(mlir::omp::BindClauseOps &result) const;
bool processCancelDirectiveName(
mlir::omp::CancelDirectiveNameClauseOps &result) const;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 6d93f245228a8..f2609df67eca0 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -67,23 +67,12 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
const ConstructQueue &queue,
ConstructQueue::const_iterator item);
-static void processHostEvalClauses(lower::AbstractConverter &converter,
- semantics::SemanticsContext &semaCtx,
- lower::StatementContext &stmtCtx,
- lower::pft::Evaluation &eval,
- mlir::Location loc);
-
namespace {
/// Structure holding information that is needed to pass host-evaluated
/// information to later lowering stages.
class HostEvalInfo {
public:
- // Allow this function access to private members in order to initialize them.
- friend void ::processHostEvalClauses(lower::AbstractConverter &,
- semantics::SemanticsContext &,
- lower::StatementContext &,
- lower::pft::Evaluation &,
- mlir::Location);
+ friend class HostEvalPatternProcessor;
/// Fill \c vars with values stored in \c ops.
///
@@ -201,6 +190,393 @@ class HostEvalInfo {
llvm::SmallVector<const semantics::Symbol *> iv;
bool loopNestApplied = false, parallelApplied = false;
};
+
+class OpenMPPatternProcessor {
+public:
+ OpenMPPatternProcessor(semantics::SemanticsContext &semaCtx)
+ : semaCtx(semaCtx) {}
+ virtual ~OpenMPPatternProcessor() = default;
+
+ /// Run the pattern from the given evaluation.
+ void process(lower::pft::Evaluation &eval) {
+ dirsToProcess = initialDirsToProcess();
+ processEval(eval);
+ }
+
+protected:
+ /// Returns the set of directives of interest at the beginning of the pattern.
+ virtual OmpDirectiveSet initialDirsToProcess() const = 0;
+
+ /// Processes a single directive and, based on it, returns the set of other
+ /// directives of interest that would be part of the pattern if nested inside
+ /// of it.
+ virtual OmpDirectiveSet processDirective(lower::pft::Evaluation &eval,
+ llvm::omp::Directive dir) = 0;
+
+ /// Obtain the list of clauses of the given OpenMP block or loop construct
+ /// evaluation. If it's not an OpenMP construct, no modifications are made to
+ /// the \c clauses output argument.
+ void extractClauses(lower::pft::Evaluation &eval, List<Clause> &clauses) {
+ const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+ if (!ompEval)
+ return;
+
+ const parser::OmpClauseList *beginClauseList = nullptr;
+ const parser::OmpClauseList *endClauseList = nullptr;
+ common::visit(
+ [&](const auto &construct) {
+ using Type = llvm::remove_cvref_t<decltype(construct)>;
+ if constexpr (std::is_same_v<Type, parser::OmpBlockConstruct> ||
+ std::is_same_v<Type, parser::OpenMPLoopConstruct>) {
+ beginClauseList = &construct.BeginDir().Clauses();
+ if (auto &endSpec = construct.EndDir())
+ endClauseList = &endSpec->Clauses();
+ }
+ },
+ ompEval->u);
+
+ assert(beginClauseList && "expected begin directive");
+ clauses.append(makeClauses(*beginClauseList, semaCtx));
+
+ if (endClauseList)
+ clauses.append(makeClauses(*endClauseList, semaCtx));
+ }
+
+private:
+ /// Decide whether an evaluation must be processed as part of the pattern.
+ ///
+ /// This is the case whenever it's an OpenMP construct and the associated
+ /// directive is part of the current set of directives of interest.
+ bool shouldProcessEval(lower::pft::Evaluation &eval) const {
+ const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+ if (!ompEval)
+ return false;
+
+ return dirsToProcess.test(parser::omp::GetOmpDirectiveName(*ompEval).v);
+ }
+
+ /// Processes an evaluation and, potentially, recursively process a single
+ /// nested evaluation.
+ ///
+ /// For a nested evaluation to be recursively processed, it must be an OpenMP
+ /// construct, have no sibling evaluations and match one of the
+ /// next-directives of interest set returned by a call to \c processDirective
+ /// on the parent evaluation.
+ void processEval(lower::pft::Evaluation &eval) {
+ if (!shouldProcessEval(eval))
+ return;
+
+ const auto &ompEval = eval.get<parser::OpenMPConstruct>();
+ OmpDirectiveSet processNested =
+ processDirective(eval, parser::omp::GetOmpDirectiveName(ompEval).v);
+
+ if (processNested.empty())
+ return;
+
+ if (lower::pft::Evaluation *nestedEval = extractOnlyOmpNestedEval(eval)) {
+ OmpDirectiveSet prevDirs = dirsToProcess;
+ dirsToProcess = processNested;
+ processEval(*nestedEval);
+ dirsToProcess = prevDirs;
+ }
+ }
+
+ /// Return the directive that is immediately nested inside of the given
+ /// \c parent evaluation, if it is its only non-end-statement nested
+ /// evaluation and it represents an OpenMP construct.
+ lower::pft::Evaluation *
+ extractOnlyOmpNestedEval(lower::pft::Evaluation &parent) {
+ if (!parent.hasNestedEvaluations())
+ return nullptr;
+
+ auto &nested = parent.getFirstNestedEvaluation();
+ if (!nested.isA<parser::OpenMPConstruct>())
+ return nullptr;
+
+ for (auto &sibling : parent.getNestedEvaluations())
+ if (&sibling != &nested && !sibling.isEndStmt())
+ return nullptr;
+
+ return &nested;
+ }
+
+protected:
+ semantics::SemanticsContext &semaCtx;
+
+private:
+ OmpDirectiveSet dirsToProcess;
+};
+
+/// Helper pattern to navigate target SPMD patterns.
+class TargetSPMDPatternProcessor : public OpenMPPatternProcessor {
+public:
+ using OpenMPPatternProcessor::OpenMPPatternProcessor;
+ virtual ~TargetSPMDPatternProcessor() = default;
+
+protected:
+ virtual OmpDirectiveSet initialDirsToProcess() const override {
+ return llvm::omp::allTargetSet;
+ }
+
+ virtual OmpDirectiveSet processDirective(lower::pft::Evaluation &,
+ llvm::omp::Directive dir) override {
+ using namespace llvm::omp;
+
+ // The default implementation does nothing, except it returns the allowed
+ // single nested directives for an SPMD kernel. If called by subclasses, it
+ // helps navigate SPMD patterns.
+ //
+ // Patterns considered SPMD:
+ // - target teams distribute parallel do [simd]
+ // - target teams loop
+ // - target parallel do [simd]
+ // - target parallel loop
+ switch (dir) {
+ case OMPD_target:
+ return topTeamsSet | topParallelSet;
+ case OMPD_target_teams:
+ case OMPD_teams:
+ return topDistributeSet | topLoopSet;
+ case OMPD_target_parallel:
+ case OMPD_parallel:
+ return topLoopSet | topDoSet;
+ default:
+ return {};
+ }
+ }
+};
+
+/// Populates the given host eval info structure after processing clauses for
+/// the given \p eval OpenMP target construct, or nested constructs, if these
+/// must be evaluated outside of the target region per the spec.
+///
+/// In particular, this will ensure that in 'target teams' and equivalent nested
+/// constructs, the \c thread_limit, \c num_teams and \c num_threads clauses
+/// will be evaluated in the host. Additionally, loop bounds and steps will also
+/// be evaluated in the host if a 'target teams distribute' or target SPMD
+/// construct is detected (i.e. 'target teams distribute parallel do [simd]',
+/// 'target parallel do [simd]' or equivalent nesting).
+///
+/// The resulting updated \c HostEvalInfo structure is intended to be used to
+/// populate the \c host_eval operands of the associated \c omp.target
+/// operation, and also to be checked and used by later lowering steps to
+/// populate the corresponding operands of the \c omp.teams, \c omp.parallel or
+/// \c omp.loop_nest operations.
+class HostEvalPatternProcessor : public TargetSPMDPatternProcessor {
+public:
+ HostEvalPatternProcessor(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ lower::StatementContext &stmtCtx, mlir::Location loc,
+ HostEvalInfo &hostEvalInfo)
+ : TargetSPMDPatternProcessor(semaCtx), converter(converter),
+ stmtCtx(stmtCtx), loc(loc), hostEvalInfo(hostEvalInfo) {}
+ virtual ~HostEvalPatternProcessor() = default;
+
+protected:
+ virtual OmpDirectiveSet processDirective(lower::pft::Evaluation &eval,
+ llvm::omp::Directive dir) override {
+ using namespace llvm::omp;
+
+ List<lower::omp::Clause> clauses;
+ extractClauses(eval, clauses);
+ ClauseProcessor cp(converter, semaCtx, clauses);
+
+ // Currently, we deal differently with e.g. `target parallel workshare` to
+ // `target parallel` with a single nested `workshare`. The first case would
+ // result in no clauses being evaluated in the host, as there's not a case
+ // for it in the below switch statement. The second case would evaluate
+ // `num_threads` clauses in the host, because `target parallel` could be
+ // followed by a `do` construct, which would make this an SPMD target
+ // region.
+ //
+ // TODO: We don't probably want to have such divergent behavior when dealing
+ // with combined directives. We need to revisit this logic without listing
+ // every possible combined directive containing a clause we'd otherwise
+ // evaluate in the host if the directive was split into its leafs.
+ switch (dir) {
+ case OMPD_teams_distribute_parallel_do:
+ case OMPD_teams_distribute_parallel_do_simd:
+ cp.processThreadLimit(stmtCtx, hostEvalInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams_distribute_parallel_do:
+ case OMPD_target_teams_distribute_parallel_do_simd:
+ cp.processNumTeams(stmtCtx, hostEvalInfo.ops);
+ [[fallthrough]];
+ case OMPD_distribute_parallel_do:
+ case OMPD_distribute_parallel_do_simd:
+ case OMPD_target_parallel_do:
+ case OMPD_target_parallel_do_simd:
+ case OMPD_target_parallel_loop:
+ case OMPD_parallel_do:
+ case OMPD_parallel_do_simd:
+ case OMPD_parallel_loop:
+ cp.processNumThreads(stmtCtx, hostEvalInfo.ops);
+ [[fallthrough]];
+ case OMPD_distribute:
+ case OMPD_distribute_simd:
+ case OMPD_do:
+ case OMPD_do_simd:
+ cp.processCollapse(loc, eval, hostEvalInfo.ops, hostEvalInfo.ops,
+ hostEvalInfo.iv);
+ return {};
+
+ case OMPD_teams:
+ cp.processThreadLimit(stmtCtx, hostEvalInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams:
+ cp.processNumTeams(stmtCtx, hostEvalInfo.ops);
+ break;
+
+ case OMPD_teams_distribute:
+ case OMPD_teams_distribute_simd:
+ cp.processThreadLimit(stmtCtx, hostEvalInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams_distribute:
+ case OMPD_target_teams_distribute_simd:
+ cp.processCollapse(loc, eval, hostEvalInfo.ops, hostEvalInfo.ops,
+ hostEvalInfo.iv);
+ cp.processNumTeams(stmtCtx, hostEvalInfo.ops);
+ return {};
+
+ case OMPD_teams_loop:
+ cp.processThreadLimit(stmtCtx, hostEvalInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams_loop:
+ cp.processNumTeams(stmtCtx, hostEvalInfo.ops);
+ [[fallthrough]];
+ case OMPD_loop:
+ cp.processCollapse(loc, eval, hostEvalInfo.ops, hostEvalInfo.ops,
+ hostEvalInfo.iv);
+ return {};
+
+ case OMPD_teams_workdistribute:
+ cp.processThreadLimit(stmtCtx, hostEvalInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams_workdistribute:
+ cp.processNumTeams(stmtCtx, hostEvalInfo.ops);
+ break;
+
+ case OMPD_target_parallel:
+ case OMPD_parallel:
+ cp.processNumThreads(stmtCtx, hostEvalInfo.ops);
+ break;
+
+ case OMPD_target:
+ break;
+
+ default:
+ return {};
+ }
+
+ // Visit nested directives as per the SPMD pattern.
+ return TargetSPMDPatternProcessor::processDirective(eval, dir);
+ }
+
+private:
+ lower::AbstractConverter &converter;
+ lower::StatementContext &stmtCtx;
+ mlir::Location loc;
+ HostEvalInfo &hostEvalInfo;
+};
+
+/// Checks target regions and, based on the directives and clauses encountered,
+/// determines its associated kernel type.
+class KernelTypePatternProcessor : protected TargetSPMDPatternProcessor {
+public:
+ KernelTypePatternProcessor(semantics::SemanticsContext &semaCtx,
+ mlir::ModuleOp moduleOp)
+ : TargetSPMDPatternProcessor(semaCtx), moduleOp(moduleOp) {}
+ virtual ~KernelTypePatternProcessor() = default;
+
+ /// Executes the pattern and returns the kernel type of the given target
+ /// region, or \c mlir::omp::TargetExecMode::generic by default for non-target
+ /// evaluations.
+ mlir::omp::TargetExecMode getKernelType(lower::pft::Evaluation &eval) {
+ execMode = mlir::omp::TargetExecMode::generic;
+ process(eval);
+ return execMode;
+ }
+
+protected:
+ virtual OmpDirectiveSet processDirective(lower::pft::Evaluation &eval,
+ llvm::omp::Directive dir) override {
+ using namespace llvm::omp;
+
+ switch (dir) {
+ case OMPD_target:
+ case OMPD_target_parallel:
+ case OMPD_parallel:
+ case OMPD_teams:
+ break;
+ case OMPD_target_teams:
+ if (hasOmpxBareClause(eval)) {
+ execMode = mlir::omp::TargetExecMode::bare;
+ return {};
+ }
+ break;
+ case OMPD_target_teams_distribute_parallel_do:
+ case OMPD_target_teams_distribute_parallel_do_simd:
+ case OMPD_target_teams_loop:
+ case OMPD_target_parallel_do:
+ case OMPD_target_parallel_do_simd:
+ case OMPD_target_parallel_loop:
+ case OMPD_teams_distribute_parallel_do:
+ case OMPD_teams_distribute_parallel_do_simd:
+ case OMPD_teams_loop:
+ case OMPD_distribute_parallel_do:
+ case OMPD_distribute_parallel_do_simd:
+ case OMPD_loop:
+ case OMPD_parallel_do:
+ case OMPD_parallel_do_simd:
+ case OMPD_do:
+ case OMPD_do_simd:
+ execMode = canPromoteSPMDToNoLoop(eval)
+ ? mlir::omp::TargetExecMode::spmd_no_loop
+ : mlir::omp::TargetExecMode::spmd;
+ return {};
+ default:
+ return {};
+ }
+
+ // Visit nested directives as per the SPMD pattern.
+ return TargetSPMDPatternProcessor::processDirective(eval, dir);
+ }
+
+private:
+ bool canPromoteSPMDToNoLoop(lower::pft::Evaluation &eval) {
+ List<lower::omp::Clause> clauses;
+ extractClauses(eval, clauses);
+
+ // First make sure the proper module attributes are present in order to
+ // perform this optimization.
+ auto ompFlags =
+ llvm::cast<mlir::omp::OffloadModuleInterface>(*moduleOp).getFlags();
+ if (!ompFlags || !ompFlags.getAssumeTeamsOversubscription() ||
+ !ompFlags.getAssumeThreadsOversubscription())
+ return false;
+
+ // The num_teams clause can break no-loop assumptions, and reductions are
+ // slower in no-loop mode.
+ return llvm::find_if(clauses, [](const Clause &clause) {
+ return std::holds_alternative<clause::NumTeams>(clause.u) ||
+ std::holds_alternative<clause::Reduction>(clause.u);
+ }) == clauses.end();
+ }
+
+ bool hasOmpxBareClause(lower::pft::Evaluation &eval) {
+ List<lower::omp::Clause> clauses;
+ extractClauses(eval, clauses);
+
+ return llvm::find_if(clauses, [](const Clause &clause) {
+ return std::holds_alternative<clause::OmpxBare>(clause.u);
+ }) != clauses.end();
+ }
+
+private:
+ mlir::ModuleOp moduleOp;
+ mlir::omp::TargetExecMode execMode;
+};
+
} // namespace
/// Stack of \see HostEvalInfo to represent the current nest of \c omp.target
@@ -384,187 +760,6 @@ extractMappedBaseValues(llvm::ArrayRef<mlir::Value> vars,
});
}
-/// Populate the global \see hostEvalInfo after processing clauses for the given
-/// \p eval OpenMP target construct, or nested constructs, if these must be
-/// evaluated outside of the target region per the spec.
-///
-/// In particular, this will ensure that in 'target teams' and equivalent nested
-/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated
-/// in the host. Additionally, loop bounds, steps and the \c num_threads clause
-/// will also be evaluated in the host if a target SPMD construct is detected
-/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting).
-///
-/// The result, stored as a global, is intended to be used to populate the \c
-/// host_eval operands of the associated \c omp.target operation, and also to be
-/// checked and used by later lowering steps to populate the corresponding
-/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest
-/// operations.
-static void processHostEvalClauses(lower::AbstractConverter &converter,
- semantics::SemanticsContext &semaCtx,
- lower::StatementContext &stmtCtx,
- lower::pft::Evaluation &eval,
- mlir::Location loc) {
- // Obtain the list of clauses of the given OpenMP block or loop construct
- // evaluation. Other evaluations passed to this lambda keep `clauses`
- // unchanged.
- auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval,
- List<Clause> &clauses) {
- const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
- if (!ompEval)
- return;
-
- const parser::OmpClauseList *beginClauseList = nullptr;
- const parser::OmpClauseList *endClauseList = nullptr;
- common::visit(
- [&](const auto &construct) {
- using Type = llvm::remove_cvref_t<decltype(construct)>;
- if constexpr (std::is_same_v<Type, parser::OmpBlockConstruct> ||
- std::is_same_v<Type, parser::OpenMPLoopConstruct>) {
- beginClauseList = &construct.BeginDir().Clauses();
- if (auto &endSpec = construct.EndDir())
- endClauseList = &endSpec->Clauses();
- }
- },
- ompEval->u);
-
- assert(beginClauseList && "expected begin directive");
- clauses.append(makeClauses(*beginClauseList, semaCtx));
-
- if (endClauseList)
- clauses.append(makeClauses(*endClauseList, semaCtx));
- };
-
- // Return the directive that is immediately nested inside of the given
- // `parent` evaluation, if it is its only non-end-statement nested evaluation
- // and it represents an OpenMP construct.
- auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent)
- ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/186166