[flang] [llvm] [mlir] [flang][OpenMP][MLIR] Add MLIR op for loop directive (PR #113911)
Kareem Ergawy via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 29 01:11:13 PDT 2024
https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/113911
>From a546224e76bf5595bd20621d3591c94ea253fb06 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Fri, 25 Oct 2024 04:00:52 -0500
Subject: [PATCH 1/2] [flang][OpenMP] Parse `bind` clause for `loop`
directive.
Adds parsing for the `bind` clause. The clause was already part of the
`loop` directive's definition, but parsing was still missing.
---
flang/include/flang/Parser/dump-parse-tree.h | 2 +
flang/include/flang/Parser/parse-tree.h | 7 ++++
.../flang/Semantics/openmp-directive-sets.h | 2 +
flang/lib/Parser/openmp-parsers.cpp | 9 +++++
flang/lib/Semantics/check-omp-structure.cpp | 39 ++++++++++++++++++-
flang/lib/Semantics/check-omp-structure.h | 1 +
flang/lib/Semantics/resolve-directives.cpp | 1 +
.../Parser/OpenMP/target-loop-unparse.f90 | 16 +++++++-
flang/test/Semantics/OpenMP/loop-bind.f90 | 33 ++++++++++++++++
.../Semantics/OpenMP/nested-distribute.f90 | 6 +--
flang/test/Semantics/OpenMP/nested-teams.f90 | 2 +-
llvm/include/llvm/Frontend/OpenMP/OMP.td | 1 +
12 files changed, 111 insertions(+), 8 deletions(-)
create mode 100644 flang/test/Semantics/OpenMP/loop-bind.f90
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 31ad1b7c6ce5b5..ef6a774ec2c3d2 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -551,6 +551,8 @@ class ParseTreeDumper {
NODE_ENUM(OmpGrainsizeClause, Prescriptiveness)
NODE(parser, OmpNumTasksClause)
NODE_ENUM(OmpNumTasksClause, Prescriptiveness)
+ NODE(parser, OmpBindClause)
+ NODE_ENUM(OmpBindClause, Type)
NODE(parser, OmpProcBindClause)
NODE_ENUM(OmpProcBindClause, Type)
NODE_ENUM(OmpReductionClause, ReductionModifier)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 506a470c5557b7..dbdb8cf764341e 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -3714,6 +3714,13 @@ struct OmpNumTasksClause {
std::tuple<std::optional<Prescriptiveness>, ScalarIntExpr> t;
};
+// OMP 5.2 11.7.1 bind-clause ->
+// BIND( PARALLEL | TEAMS | THREAD )
+struct OmpBindClause {
+ ENUM_CLASS(Type, Parallel, Teams, Thread)
+ WRAPPER_CLASS_BOILERPLATE(OmpBindClause, Type);
+};
+
// OpenMP Clauses
struct OmpClause {
UNION_CLASS_BOILERPLATE(OmpClause);
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 50d6d5b59ef7dd..da360a34ca86df 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -169,6 +169,7 @@ static const OmpDirectiveSet topTeamsSet{
Directive::OMPD_teams_distribute_parallel_do,
Directive::OMPD_teams_distribute_parallel_do_simd,
Directive::OMPD_teams_distribute_simd,
+ Directive::OMPD_teams_loop,
};
static const OmpDirectiveSet allTeamsSet{
@@ -365,6 +366,7 @@ static const OmpDirectiveSet nestedTeamsAllowedSet{
Directive::OMPD_distribute_parallel_do,
Directive::OMPD_distribute_parallel_do_simd,
Directive::OMPD_distribute_simd,
+ Directive::OMPD_loop,
Directive::OMPD_parallel,
Directive::OMPD_parallel_do,
Directive::OMPD_parallel_do_simd,
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 3ca4e93a6c9b93..9ce135411e2085 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -427,6 +427,12 @@ TYPE_PARSER(construct<OmpLastprivateClause>(
pure(OmpLastprivateClause::LastprivateModifier::Conditional) / ":"),
Parser<OmpObjectList>{}))
+// OMP 5.2 11.7.1 BIND ( PARALLEL | TEAMS | THREAD )
+TYPE_PARSER(construct<OmpBindClause>(
+ "PARALLEL" >> pure(OmpBindClause::Type::Parallel) ||
+ "TEAMS" >> pure(OmpBindClause::Type::Teams) ||
+ "THREAD" >> pure(OmpBindClause::Type::Thread)))
+
TYPE_PARSER(
"ACQUIRE" >> construct<OmpClause>(construct<OmpClause::Acquire>()) ||
"ACQ_REL" >> construct<OmpClause>(construct<OmpClause::AcqRel>()) ||
@@ -441,6 +447,8 @@ TYPE_PARSER(
"ATOMIC_DEFAULT_MEM_ORDER" >>
construct<OmpClause>(construct<OmpClause::AtomicDefaultMemOrder>(
parenthesized(Parser<OmpAtomicDefaultMemOrderClause>{}))) ||
+ "BIND" >> construct<OmpClause>(construct<OmpClause::Bind>(
+ parenthesized(Parser<OmpBindClause>{}))) ||
"COLLAPSE" >> construct<OmpClause>(construct<OmpClause::Collapse>(
parenthesized(scalarIntConstantExpr))) ||
"COPYIN" >> construct<OmpClause>(construct<OmpClause::Copyin>(
@@ -615,6 +623,7 @@ TYPE_PARSER(sourced(construct<OmpLoopDirective>(first(
"TEAMS DISTRIBUTE SIMD" >>
pure(llvm::omp::Directive::OMPD_teams_distribute_simd),
"TEAMS DISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_teams_distribute),
+ "TEAMS LOOP" >> pure(llvm::omp::Directive::OMPD_teams_loop),
"TILE" >> pure(llvm::omp::Directive::OMPD_tile),
"UNROLL" >> pure(llvm::omp::Directive::OMPD_unroll)))))
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index 599cc61a83bf0a..be8512bb33f1bd 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -363,13 +363,47 @@ void OmpStructureChecker::HasInvalidDistributeNesting(
"region."_err_en_US);
}
}
+void OmpStructureChecker::HasInvalidLoopBinding(
+    const parser::OpenMPLoopConstruct &x) {
+  const auto &beginLoopDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
+  const auto &beginDir{std::get<parser::OmpLoopDirective>(beginLoopDir.t)};
+  const auto &clauseList{std::get<parser::OmpClauseList>(beginLoopDir.t)};
+
+  // Reports `msg` once per BIND clause whose binding is not TEAMS.
+  auto reportNonTeamsBinding{[&](parser::MessageFixedText msg) {
+    for (const auto &clause : clauseList.v) {
+      const auto *bind{std::get_if<parser::OmpClause::Bind>(&clause.u)};
+      if (bind && bind->v.v != parser::OmpBindClause::Type::Teams) {
+        context_.Say(beginDir.source, msg);
+      }
+    }
+  }};
+  // Standalone LOOP whose parent context is TEAMS or TARGET TEAMS.
+  const bool nestedInTeams{beginDir.v == llvm::omp::Directive::OMPD_loop &&
+      CurrentDirectiveIsNested() &&
+      OmpDirectiveSet{llvm::omp::OMPD_teams, llvm::omp::OMPD_target_teams}
+          .test(GetContextParent().directive)};
+  if (nestedInTeams) {
+    reportNonTeamsBinding(
+        "`BIND(TEAMS)` must be specified since the `LOOP` region is "
+        "strictly nested inside a `TEAMS` region."_err_en_US);
+  }
+  // LOOP combined with TEAMS: TEAMS LOOP or TARGET TEAMS LOOP.
+  if (OmpDirectiveSet{llvm::omp::OMPD_teams_loop,
+          llvm::omp::OMPD_target_teams_loop}
+          .test(beginDir.v)) {
+    reportNonTeamsBinding(
+        "`BIND(TEAMS)` must be specified since the `LOOP` directive is "
+        "combined with a `TEAMS` construct."_err_en_US);
+  }
+}
void OmpStructureChecker::HasInvalidTeamsNesting(
const llvm::omp::Directive &dir, const parser::CharBlock &source) {
if (!llvm::omp::nestedTeamsAllowedSet.test(dir)) {
context_.Say(source,
- "Only `DISTRIBUTE` or `PARALLEL` regions are allowed to be strictly "
- "nested inside `TEAMS` region."_err_en_US);
+ "Only `DISTRIBUTE`, `PARALLEL`, or `LOOP` regions are allowed to be "
+ "strictly nested inside `TEAMS` region."_err_en_US);
}
}
@@ -532,6 +566,7 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
CheckLoopItrVariableIsInt(x);
CheckAssociatedLoopConstraints(x);
HasInvalidDistributeNesting(x);
+ HasInvalidLoopBinding(x);
if (CurrentDirectiveIsNested() &&
llvm::omp::topTeamsSet.test(GetContextParent().directive)) {
HasInvalidTeamsNesting(beginDir.v, beginDir.source);
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 237569bc40c483..b1f19c19bd2375 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -148,6 +148,7 @@ class OmpStructureChecker
void HasInvalidTeamsNesting(
const llvm::omp::Directive &dir, const parser::CharBlock &source);
void HasInvalidDistributeNesting(const parser::OpenMPLoopConstruct &x);
+ void HasInvalidLoopBinding(const parser::OpenMPLoopConstruct &x);
// specific clause related
bool ScheduleModifierHasType(const parser::OmpScheduleClause &,
const parser::OmpScheduleModifierType::ModType &);
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 014b7987a658bd..7640e93bc3fa73 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1655,6 +1655,7 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) {
case llvm::omp::Directive::OMPD_teams_distribute_parallel_do:
case llvm::omp::Directive::OMPD_teams_distribute_parallel_do_simd:
case llvm::omp::Directive::OMPD_teams_distribute_simd:
+ case llvm::omp::Directive::OMPD_teams_loop:
case llvm::omp::Directive::OMPD_tile:
case llvm::omp::Directive::OMPD_unroll:
PushContext(beginDir.source, beginDir.v);
diff --git a/flang/test/Parser/OpenMP/target-loop-unparse.f90 b/flang/test/Parser/OpenMP/target-loop-unparse.f90
index 3ee2fcef075a37..b2047070496527 100644
--- a/flang/test/Parser/OpenMP/target-loop-unparse.f90
+++ b/flang/test/Parser/OpenMP/target-loop-unparse.f90
@@ -1,6 +1,8 @@
+! RUN: %flang_fc1 -fdebug-unparse -fopenmp -fopenmp-version=50 %s | \
+! RUN: FileCheck --ignore-case %s
-! RUN: %flang_fc1 -fdebug-unparse -fopenmp %s | FileCheck --ignore-case %s
-! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp %s | FileCheck --check-prefix="PARSE-TREE" %s
+! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp -fopenmp-version=50 %s | \
+! RUN: FileCheck --check-prefix="PARSE-TREE" %s
! Check for parsing of loop directive
@@ -14,6 +16,16 @@ subroutine test_loop
j = j + 1
end do
!$omp end loop
+
+ !PARSE-TREE: OmpBeginLoopDirective
+ !PARSE-TREE-NEXT: OmpLoopDirective -> llvm::omp::Directive = loop
+ !PARSE-TREE-NEXT: OmpClauseList -> OmpClause -> Bind -> OmpBindClause -> Type = Thread
+ !CHECK: !$omp loop
+ !$omp loop bind(thread)
+ do i=1,10
+ j = j + 1
+ end do
+ !$omp end loop
end subroutine
subroutine test_target_loop
diff --git a/flang/test/Semantics/OpenMP/loop-bind.f90 b/flang/test/Semantics/OpenMP/loop-bind.f90
new file mode 100644
index 00000000000000..f3aa9d19fe989e
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/loop-bind.f90
@@ -0,0 +1,33 @@
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp -fopenmp-version=50
+
+! OpenMP Version 5.0
+! Check OpenMP construct validity for the following directives:
+! 11.7 Loop directive
+
+program main
+ integer :: i, x
+
+ !$omp teams
+ !ERROR: `BIND(TEAMS)` must be specified since the `LOOP` region is strictly nested inside a `TEAMS` region.
+ !$omp loop bind(thread)
+ do i = 1, 10
+ x = x + 1
+ end do
+ !$omp end loop
+ !$omp end teams
+
+ !ERROR: `BIND(TEAMS)` must be specified since the `LOOP` directive is combined with a `TEAMS` construct.
+ !$omp target teams loop bind(thread)
+ do i = 1, 10
+ x = x + 1
+ end do
+ !$omp end target teams loop
+
+ !ERROR: `BIND(TEAMS)` must be specified since the `LOOP` directive is combined with a `TEAMS` construct.
+ !$omp teams loop bind(thread)
+ do i = 1, 10
+ x = x + 1
+ end do
+ !$omp end teams loop
+
+end program main
diff --git a/flang/test/Semantics/OpenMP/nested-distribute.f90 b/flang/test/Semantics/OpenMP/nested-distribute.f90
index ba8c3bf04b3377..c212763cba1df8 100644
--- a/flang/test/Semantics/OpenMP/nested-distribute.f90
+++ b/flang/test/Semantics/OpenMP/nested-distribute.f90
@@ -21,7 +21,7 @@ program main
!$omp teams
do i = 1, N
- !ERROR: Only `DISTRIBUTE` or `PARALLEL` regions are allowed to be strictly nested inside `TEAMS` region.
+ !ERROR: Only `DISTRIBUTE`, `PARALLEL`, or `LOOP` regions are allowed to be strictly nested inside `TEAMS` region.
!$omp task
do k = 1, N
a = 3.14
@@ -50,7 +50,7 @@ program main
!$omp end parallel
!$omp teams
- !ERROR: Only `DISTRIBUTE` or `PARALLEL` regions are allowed to be strictly nested inside `TEAMS` region.
+ !ERROR: Only `DISTRIBUTE`, `PARALLEL`, or `LOOP` regions are allowed to be strictly nested inside `TEAMS` region.
!$omp target
!ERROR: `DISTRIBUTE` region has to be strictly nested inside `TEAMS` region.
!$omp distribute
@@ -82,7 +82,7 @@ program main
!$omp end target teams
!$omp teams
- !ERROR: Only `DISTRIBUTE` or `PARALLEL` regions are allowed to be strictly nested inside `TEAMS` region.
+ !ERROR: Only `DISTRIBUTE`, `PARALLEL`, or `LOOP` regions are allowed to be strictly nested inside `TEAMS` region.
!$omp task
do k = 1,10
print *, "hello"
diff --git a/flang/test/Semantics/OpenMP/nested-teams.f90 b/flang/test/Semantics/OpenMP/nested-teams.f90
index f3b96b0ab43903..fd2a3a61e34357 100644
--- a/flang/test/Semantics/OpenMP/nested-teams.f90
+++ b/flang/test/Semantics/OpenMP/nested-teams.f90
@@ -59,7 +59,7 @@ program main
!$omp target
!$omp teams
- !ERROR: Only `DISTRIBUTE` or `PARALLEL` regions are allowed to be strictly nested inside `TEAMS` region.
+ !ERROR: Only `DISTRIBUTE`, `PARALLEL`, or `LOOP` regions are allowed to be strictly nested inside `TEAMS` region.
!ERROR: TEAMS region can only be strictly nested within the implicit parallel region or TARGET region
!$omp teams
a = 3.14
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 70179bab475779..feb8eb5c2abf49 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -73,6 +73,7 @@ def OMPC_AtomicDefaultMemOrder : Clause<"atomic_default_mem_order"> {
}
def OMPC_Bind : Clause<"bind"> {
let clangClass = "OMPBindClause";
+ let flangClass = "OmpBindClause";
}
def OMP_CANCELLATION_CONSTRUCT_Parallel : ClauseVal<"parallel", 1, 1> {}
def OMP_CANCELLATION_CONSTRUCT_Loop : ClauseVal<"loop", 2, 1> {}
>From 7582e2b9e37a991cf87b7d6f790c690fb915efae Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 28 Oct 2024 00:04:15 -0500
Subject: [PATCH 2/2] [flang][OpenMP][MLIR] Add MLIR op for `loop` directive
Adds an MLIR op that corresponds to the `loop` directive.
---
llvm/include/llvm/Frontend/OpenMP/OMP.td | 11 +++++
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 25 +++++++++++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 44 +++++++++++++++++++
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 23 ++++++++++
mlir/test/Dialect/OpenMP/invalid.mlir | 30 +++++++++++++
mlir/test/Dialect/OpenMP/ops.mlir | 40 +++++++++++++++++
6 files changed, 173 insertions(+)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index feb8eb5c2abf49..9668ab1e90ecf3 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -71,10 +71,21 @@ def OMPC_AtomicDefaultMemOrder : Clause<"atomic_default_mem_order"> {
let clangClass = "OMPAtomicDefaultMemOrderClause";
let flangClass = "OmpAtomicDefaultMemOrderClause";
}
+
+def OMP_BIND_parallel : ClauseVal<"parallel",1,1> {}
+def OMP_BIND_teams : ClauseVal<"teams",2,1> {}
+def OMP_BIND_thread : ClauseVal<"thread",3,1> { let isDefault = true; }
def OMPC_Bind : Clause<"bind"> {
let clangClass = "OMPBindClause";
let flangClass = "OmpBindClause";
+ let enumClauseValue = "BindKind";
+ let allowedClauseValues = [
+ OMP_BIND_parallel,
+ OMP_BIND_teams,
+ OMP_BIND_thread
+ ];
}
+
def OMP_CANCELLATION_CONSTRUCT_Parallel : ClauseVal<"parallel", 1, 1> {}
def OMP_CANCELLATION_CONSTRUCT_Loop : ClauseVal<"loop", 2, 1> {}
def OMP_CANCELLATION_CONSTRUCT_Sections : ClauseVal<"sections", 3, 1> {}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 886554f66afffc..855deab94b2f16 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -107,6 +107,31 @@ class OpenMP_CancelDirectiveNameClauseSkip<
def OpenMP_CancelDirectiveNameClause : OpenMP_CancelDirectiveNameClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V5.2: [11.7.1] `bind` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_BindClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ OptionalAttr<BindKindAttr>:$bind_kind
+ );
+
+ let optAssemblyFormat = [{
+ `bind` `(` custom<ClauseAttr>($bind_kind) `)`
+ }];
+
+ let description = [{
+ The `bind` clause specifies the binding region of the construct on which it
+ appears.
+ }];
+}
+
+def OpenMP_BindClause : OpenMP_BindClauseSkip<>;
+
//===----------------------------------------------------------------------===//
// V5.2: [5.7.2] `copyprivate` clause
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 626539cb7bde42..4f16108a7a585f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -382,6 +382,50 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
// 2.9.2 Workshare Loop Construct
//===----------------------------------------------------------------------===//
+def LoopOp : OpenMP_Op<"loop", traits = [
+    DeclareOpInterfaceMethods<LoopWrapperInterface>, NoTerminator, SingleBlock,
+    AttrSizedOperandSegments
+  ], clauses = [
+    OpenMP_BindClause, OpenMP_PrivateClause, OpenMP_OrderClause,
+    OpenMP_ReductionClause
+  ], singleRegion = true> {
+  let summary = "loop construct";
+  let description = [{
+    A loop construct specifies that the logical iterations of the associated loops
+    may execute concurrently and permits the encountering threads to execute the
+    loop accordingly. A loop construct is a worksharing construct if its binding
+    region is the innermost enclosing parallel region; otherwise, it is not a
+    worksharing construct. The directive asserts that the iterations of the
+    associated loops may execute in any order, including concurrently. Each logical
+    iteration is executed once per instance of the loop region that is encountered
+    by exactly one thread that is a member of the binding thread set.
+
+    The body region can only contain a single block, which must contain a single
+    operation, and this operation must be an `omp.loop_nest`.
+
+    ```
+    omp.loop <clauses> {
+      omp.loop_nest (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
+        %a = memref.load %arrA[%i1, %i2] : memref<?x?xf32>
+        %b = memref.load %arrB[%i1, %i2] : memref<?x?xf32>
+        %sum = arith.addf %a, %b : f32
+        memref.store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
+        omp.yield
+      }
+    }
+    ```
+  }] # clausesDescription;
+
+  let assemblyFormat = clausesAssemblyFormat # [{
+    custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
+        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $reduction_syms) attr-dict
+  }];
+
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
+}
+
def WsloopOp : OpenMP_Op<"wsloop", traits = [
AttrSizedOperandSegments, DeclareOpInterfaceMethods<ComposableOpInterface>,
DeclareOpInterfaceMethods<LoopWrapperInterface>, NoTerminator,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e1df647d6a3c71..39accbadc8e276 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1948,6 +1948,29 @@ LogicalResult LoopWrapperInterface::verifyImpl() {
return success();
}
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult LoopOp::verify() {
+  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
+                                getReductionByref());
+}
+
+LogicalResult LoopOp::verifyRegions() {
+  Region &region = getRegion();
+
+  // Minimal checks that the only nested op is an `omp.loop_nest`. More
+  // extensive verification is done by the `LoopWrapperInterface` trait; the
+  // difference is that `omp.loop` cannot have another nested
+  // `LoopWrapperInterface` op (e.g. `omp.simd`) between it and the loop nest.
+  if (range_size(region.getOps()) != 1 || !isa<LoopNestOp>(*region.op_begin()))
+    return emitError() << "`omp.loop` expected to have a single nested "
+                          "operation which is a `omp.loop_nest`";
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// WsloopOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index fd89ec31c64a60..194daf8c2edf34 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2577,3 +2577,33 @@ func.func @omp_taskloop_invalid_composite(%lb: index, %ub: index, %step: index)
} {omp.composite}
return
}
+
+// -----
+
+func.func @omp_loop_invalid_nesting(%lb : index, %ub : index, %step : index) {
+
+ // expected-error @below {{`omp.loop` expected to have a single nested operation which is a `omp.loop_nest`}}
+ omp.loop {
+ omp.simd {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ } {omp.composite}
+ }
+
+ return
+}
+
+// -----
+
+func.func @omp_loop_invalid_binding(%lb : index, %ub : index, %step : index) {
+
+  // expected-error @below {{custom op 'omp.loop' invalid clause value: 'dummy_value'}}
+  omp.loop bind(dummy_value) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 6f11b451fa00a3..47b8890b851c10 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2749,3 +2749,43 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
return
}
+
+// CHECK-LABEL: omp_loop
+func.func @omp_loop(%lb : index, %ub : index, %step : index) {
+ // CHECK: omp.loop {
+ omp.loop {
+ // CHECK: omp.loop_nest {{.*}} {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ // CHECK: omp.yield
+ omp.yield
+ // CHECK: }
+ }
+ // CHECK: }
+ }
+
+ // CHECK: omp.loop bind(teams) {
+ omp.loop bind(teams) {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ // CHECK: }
+ }
+
+ // CHECK: omp.loop bind(parallel) {
+ omp.loop bind(parallel) {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ // CHECK: }
+ }
+
+ // CHECK: omp.loop bind(thread) {
+ omp.loop bind(thread) {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ // CHECK: }
+ }
+
+ return
+}
More information about the llvm-commits
mailing list