[Mlir-commits] [mlir] 7734138 - [mlir][OpenMP] allow cancellation to not be directly nested (#134084)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 14 02:39:06 PDT 2025
Author: Tom Eccles
Date: 2025-04-14T10:39:03+01:00
New Revision: 77341388a77b1442b3a54d745fc269dabb175f0c
URL: https://github.com/llvm/llvm-project/commit/77341388a77b1442b3a54d745fc269dabb175f0c
DIFF: https://github.com/llvm/llvm-project/commit/77341388a77b1442b3a54d745fc269dabb175f0c.diff
LOG: [mlir][OpenMP] allow cancellation to not be directly nested (#134084)
omp.cancel and omp.cancellation_point contain an attribute describing the
type of parent construct that should be cancelled, e.g.
```
!$omp cancel do
```
must be inside an omp.wsloop. Previously the verifier required the
immediate parent to be this operation. This is too strict, because
code like the following is valid:
```
!$omp parallel do
do i = 1, N
if (cond) then
!$omp cancel do
endif
enddo
```
This patch relaxes the verifier so that it checks the nearest enclosing
OpenMP-dialect operation rather than the immediate parent; intervening
non-OpenMP operations (such as scf.if) are now allowed.
Added:
Modified:
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/test/Dialect/OpenMP/invalid.mlir
mlir/test/Dialect/OpenMP/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index cf0b4bf6e95ed..dd701da507fc6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3162,24 +3162,32 @@ void CancelOp::build(OpBuilder &builder, OperationState &state,
CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
}
+static Operation *getParentInSameDialect(Operation *thisOp) {
+ Operation *parent = thisOp->getParentOp();
+ while (parent) {
+ if (parent->getDialect() == thisOp->getDialect())
+ return parent;
+ parent = parent->getParentOp();
+ }
+ return nullptr;
+}
+
LogicalResult CancelOp::verify() {
ClauseCancellationConstructType cct = getCancelDirective();
- Operation *parentOp = (*this)->getParentOp();
-
- if (!parentOp) {
- return emitOpError() << "must be used within a region supporting "
- "cancel directive";
- }
+ // The next OpenMP operation in the chain of parents
+ Operation *structuralParent = getParentInSameDialect((*this).getOperation());
+ if (!structuralParent)
+ return emitOpError() << "Orphaned cancel construct";
if ((cct == ClauseCancellationConstructType::Parallel) &&
- !isa<ParallelOp>(parentOp)) {
+ !mlir::isa<ParallelOp>(structuralParent)) {
return emitOpError() << "cancel parallel must appear "
<< "inside a parallel region";
}
if (cct == ClauseCancellationConstructType::Loop) {
- auto loopOp = dyn_cast<LoopNestOp>(parentOp);
- auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
- loopOp ? loopOp->getParentOp() : nullptr);
+ // structural parent will be omp.loop_nest, directly nested inside
+ // omp.wsloop
+ auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
if (!wsloopOp) {
return emitOpError()
@@ -3195,12 +3203,15 @@ LogicalResult CancelOp::verify() {
}
} else if (cct == ClauseCancellationConstructType::Sections) {
- if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
+ // structural parent will be an omp.section, directly nested inside
+ // omp.sections
+ auto sectionsOp =
+ mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
+ if (!sectionsOp) {
return emitOpError() << "cancel sections must appear "
<< "inside a sections region";
}
- if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
- cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
+ if (sectionsOp.getNowait()) {
return emitError() << "A sections construct that is canceled "
<< "must not have a nowait clause";
}
@@ -3220,25 +3231,25 @@ void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
LogicalResult CancellationPointOp::verify() {
ClauseCancellationConstructType cct = getCancelDirective();
- Operation *parentOp = (*this)->getParentOp();
-
- if (!parentOp) {
- return emitOpError() << "must be used within a region supporting "
- "cancellation point directive";
- }
+ // The next OpenMP operation in the chain of parents
+ Operation *structuralParent = getParentInSameDialect((*this).getOperation());
+ if (!structuralParent)
+ return emitOpError() << "Orphaned cancellation point";
if ((cct == ClauseCancellationConstructType::Parallel) &&
- !(isa<ParallelOp>(parentOp))) {
+ !mlir::isa<ParallelOp>(structuralParent)) {
return emitOpError() << "cancellation point parallel must appear "
<< "inside a parallel region";
}
+ // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
+ // find the wsloop
if ((cct == ClauseCancellationConstructType::Loop) &&
- (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
+ !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
return emitOpError() << "cancellation point loop must appear "
<< "inside a worksharing-loop region";
}
if ((cct == ClauseCancellationConstructType::Sections) &&
- !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
+ !mlir::isa<omp::SectionOp>(structuralParent)) {
return emitOpError() << "cancellation point sections must appear "
<< "inside a sections region";
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e08adb08f7e99..41e17881afd30 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1710,6 +1710,14 @@ func.func @omp_task(%mem: memref<1xf32>) {
// -----
+func.func @omp_cancel() {
+ // expected-error @below {{Orphaned cancel construct}}
+ omp.cancel cancellation_construct_type(parallel)
+ return
+}
+
+// -----
+
func.func @omp_cancel() {
omp.sections {
// expected-error @below {{cancel parallel must appear inside a parallel region}}
@@ -1789,6 +1797,14 @@ func.func @omp_cancel5() -> () {
// -----
+func.func @omp_cancellationpoint() {
+ // expected-error @below {{Orphaned cancellation point}}
+ omp.cancellation_point cancellation_construct_type(parallel)
+ return
+}
+
+// -----
+
func.func @omp_cancellationpoint() {
omp.sections {
// expected-error @below {{cancellation point parallel must appear inside a parallel region}}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 6bc2500471997..d5e2bfa5d3949 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2201,6 +2201,48 @@ func.func @omp_cancel_sections() -> () {
return
}
+func.func @omp_cancel_parallel_nested(%if_cond : i1) -> () {
+ omp.parallel {
+ scf.if %if_cond {
+ // CHECK: omp.cancel cancellation_construct_type(parallel)
+ omp.cancel cancellation_construct_type(parallel)
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
+func.func @omp_cancel_wsloop_nested(%lb : index, %ub : index, %step : index,
+ %if_cond : i1) {
+ omp.wsloop {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ scf.if %if_cond {
+ // CHECK: omp.cancel cancellation_construct_type(loop)
+ omp.cancel cancellation_construct_type(loop)
+ }
+ // CHECK: omp.yield
+ omp.yield
+ }
+ }
+ return
+}
+
+func.func @omp_cancel_sections_nested(%if_cond : i1) -> () {
+ omp.sections {
+ omp.section {
+ scf.if %if_cond {
+ // CHECK: omp.cancel cancellation_construct_type(sections)
+ omp.cancel cancellation_construct_type(sections)
+ }
+ omp.terminator
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
func.func @omp_cancellationpoint_parallel() -> () {
omp.parallel {
// CHECK: omp.cancellation_point cancellation_construct_type(parallel)
@@ -2241,6 +2283,46 @@ func.func @omp_cancellationpoint_sections() -> () {
return
}
+func.func @omp_cancellationpoint_parallel_nested(%if_cond : i1) -> () {
+ omp.parallel {
+ scf.if %if_cond {
+ // CHECK: omp.cancellation_point cancellation_construct_type(parallel)
+ omp.cancellation_point cancellation_construct_type(parallel)
+ }
+ omp.terminator
+ }
+ return
+}
+
+func.func @omp_cancellationpoint_wsloop_nested(%lb : index, %ub : index, %step : index, %if_cond : i1) {
+ omp.wsloop {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ scf.if %if_cond {
+ // CHECK: omp.cancellation_point cancellation_construct_type(loop)
+ omp.cancellation_point cancellation_construct_type(loop)
+ }
+ // CHECK: omp.yield
+ omp.yield
+ }
+ }
+ return
+}
+
+func.func @omp_cancellationpoint_sections_nested(%if_cond : i1) -> () {
+ omp.sections {
+ omp.section {
+ scf.if %if_cond {
+ // CHECK: omp.cancellation_point cancellation_construct_type(sections)
+ omp.cancellation_point cancellation_construct_type(sections)
+ }
+ omp.terminator
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
// CHECK-LABEL: @omp_taskgroup_no_tasks
func.func @omp_taskgroup_no_tasks() -> () {
More information about the Mlir-commits mailing list