[Mlir-commits] [mlir] [mlir][OpenMP] allow cancellation to not be directly nested (PR #134084)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 2 06:35:19 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Tom Eccles (tblah)
<details>
<summary>Changes</summary>
`omp.cancel` and `omp.cancellation_point` contain an attribute describing the type of the enclosing construct that should be cancelled, e.g.
```
!$omp cancel do
```
This directive must be inside a wsloop. Previously, the verifier required the immediate parent to be that operation. This is not quite right, because something like the following is valid:
```
!$omp parallel do
do i = 1, N
if (cond) then
!$omp cancel do
endif
enddo
```
This patch relaxes the verifier so that it only requires some ancestor operation (not necessarily the immediate parent) to be of the matching construct type.
---
Full diff: https://github.com/llvm/llvm-project/pull/134084.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+10-22)
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+82)
``````````diff
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 882bc4071482f..e45d6d9fb3831 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3103,22 +3103,15 @@ void CancelOp::build(OpBuilder &builder, OperationState &state,
LogicalResult CancelOp::verify() {
ClauseCancellationConstructType cct = getCancelDirective();
- Operation *parentOp = (*this)->getParentOp();
-
- if (!parentOp) {
- return emitOpError() << "must be used within a region supporting "
- "cancel directive";
- }
+ Operation *thisOp = (*this).getOperation();
if ((cct == ClauseCancellationConstructType::Parallel) &&
- !isa<ParallelOp>(parentOp)) {
+ !thisOp->getParentOfType<ParallelOp>()) {
return emitOpError() << "cancel parallel must appear "
<< "inside a parallel region";
}
if (cct == ClauseCancellationConstructType::Loop) {
- auto loopOp = dyn_cast<LoopNestOp>(parentOp);
- auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
- loopOp ? loopOp->getParentOp() : nullptr);
+ auto wsloopOp = thisOp->getParentOfType<WsloopOp>();
if (!wsloopOp) {
return emitOpError()
@@ -3134,12 +3127,12 @@ LogicalResult CancelOp::verify() {
}
} else if (cct == ClauseCancellationConstructType::Sections) {
- if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
+ auto sectionsOp = thisOp->getParentOfType<SectionsOp>();
+ if (!sectionsOp) {
return emitOpError() << "cancel sections must appear "
<< "inside a sections region";
}
- if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
- cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
+ if (sectionsOp.getNowait()) {
return emitError() << "A sections construct that is canceled "
<< "must not have a nowait clause";
}
@@ -3159,25 +3152,20 @@ void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
LogicalResult CancellationPointOp::verify() {
ClauseCancellationConstructType cct = getCancelDirective();
- Operation *parentOp = (*this)->getParentOp();
-
- if (!parentOp) {
- return emitOpError() << "must be used within a region supporting "
- "cancellation point directive";
- }
+ Operation *thisOp = (*this).getOperation();
if ((cct == ClauseCancellationConstructType::Parallel) &&
- !(isa<ParallelOp>(parentOp))) {
+ !thisOp->getParentOfType<ParallelOp>()) {
return emitOpError() << "cancellation point parallel must appear "
<< "inside a parallel region";
}
if ((cct == ClauseCancellationConstructType::Loop) &&
- (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
+ !thisOp->getParentOfType<WsloopOp>()) {
return emitOpError() << "cancellation point loop must appear "
<< "inside a worksharing-loop region";
}
if ((cct == ClauseCancellationConstructType::Sections) &&
- !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
+ !thisOp->getParentOfType<SectionsOp>()) {
return emitOpError() << "cancellation point sections must appear "
<< "inside a sections region";
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index a5cf789402726..378a841ae62df 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2201,6 +2201,48 @@ func.func @omp_cancel_sections() -> () {
return
}
+func.func @omp_cancel_parallel_nested(%if_cond : i1) -> () {
+ omp.parallel {
+ scf.if %if_cond {
+ // CHECK: omp.cancel cancellation_construct_type(parallel)
+ omp.cancel cancellation_construct_type(parallel)
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
+func.func @omp_cancel_wsloop_nested(%lb : index, %ub : index, %step : index,
+ %if_cond : i1) {
+ omp.wsloop {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ scf.if %if_cond {
+ // CHECK: omp.cancel cancellation_construct_type(loop)
+ omp.cancel cancellation_construct_type(loop)
+ }
+ // CHECK: omp.yield
+ omp.yield
+ }
+ }
+ return
+}
+
+func.func @omp_cancel_sections_nested(%if_cond : i1) -> () {
+ omp.sections {
+ omp.section {
+ scf.if %if_cond {
+ // CHECK: omp.cancel cancellation_construct_type(sections)
+ omp.cancel cancellation_construct_type(sections)
+ }
+ omp.terminator
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
func.func @omp_cancellationpoint_parallel() -> () {
omp.parallel {
// CHECK: omp.cancellation_point cancellation_construct_type(parallel)
@@ -2241,6 +2283,46 @@ func.func @omp_cancellationpoint_sections() -> () {
return
}
+func.func @omp_cancellationpoint_parallel_nested(%if_cond : i1) -> () {
+ omp.parallel {
+ scf.if %if_cond {
+ // CHECK: omp.cancellation_point cancellation_construct_type(parallel)
+ omp.cancellation_point cancellation_construct_type(parallel)
+ }
+ omp.terminator
+ }
+ return
+}
+
+func.func @omp_cancellationpoint_wsloop_nested(%lb : index, %ub : index, %step : index, %if_cond : i1) {
+ omp.wsloop {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ scf.if %if_cond {
+ // CHECK: omp.cancellation_point cancellation_construct_type(loop)
+ omp.cancellation_point cancellation_construct_type(loop)
+ }
+ // CHECK: omp.yield
+ omp.yield
+ }
+ }
+ return
+}
+
+func.func @omp_cancellationpoint_sections_nested(%if_cond : i1) -> () {
+ omp.sections {
+ omp.section {
+ scf.if %if_cond {
+ // CHECK: omp.cancellation_point cancellation_construct_type(sections)
+ omp.cancellation_point cancellation_construct_type(sections)
+ }
+ omp.terminator
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
// CHECK-LABEL: @omp_taskgroup_no_tasks
func.func @omp_taskgroup_no_tasks() -> () {
``````````
</details>
https://github.com/llvm/llvm-project/pull/134084
More information about the Mlir-commits
mailing list