[Mlir-commits] [mlir] [mlir][OpenMP] allow cancellation to not be directly nested (PR #134084)

Tom Eccles llvmlistbot at llvm.org
Wed Apr 2 06:34:44 PDT 2025


https://github.com/tblah created https://github.com/llvm/llvm-project/pull/134084

omp.cancel and omp.cancellation_point contain an attribute describing the type of parent construct which should be cancelled, e.g.
```
!$omp cancel do
```
This must be inside of a wsloop. Previously the verifier required the immediate parent to be this operation. This is not quite right because something like the following is valid:
```
!$omp parallel do
do i = 1, N
  if (cond) then
    !$omp cancel do
  endif
enddo
```

This patch relaxes the verifier to only require that some parent operation matches (not necessarily the immediate parent).

>From 66d0f84a3db19c46acdafb96997bba602b9b354f Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 2 Apr 2025 13:21:41 +0000
Subject: [PATCH] [mlir][OpenMP] allow cancellation to not be directly nested

omp.cancel and omp.cancellation_point contain an attribute describing
the type of parent construct which should be cancelled, e.g.
```
!$omp cancel do
```
This must be inside of a wsloop. Previously the verifier required the
immediate parent to be this operation. This is not quite right because
something like the following is valid:
```
!$omp parallel do
do i = 1, N
  if (cond) then
    !$omp cancel do
  endif
enddo
```

This patch relaxes the verifier to only require that some parent
operation matches (not necessarily the immediate parent).
---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 32 +++-----
 mlir/test/Dialect/OpenMP/ops.mlir            | 82 ++++++++++++++++++++
 2 files changed, 92 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 882bc4071482f..e45d6d9fb3831 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3103,22 +3103,15 @@ void CancelOp::build(OpBuilder &builder, OperationState &state,
 
 LogicalResult CancelOp::verify() {
   ClauseCancellationConstructType cct = getCancelDirective();
-  Operation *parentOp = (*this)->getParentOp();
-
-  if (!parentOp) {
-    return emitOpError() << "must be used within a region supporting "
-                            "cancel directive";
-  }
+  Operation *thisOp = (*this).getOperation();
 
   if ((cct == ClauseCancellationConstructType::Parallel) &&
-      !isa<ParallelOp>(parentOp)) {
+      !thisOp->getParentOfType<ParallelOp>()) {
     return emitOpError() << "cancel parallel must appear "
                          << "inside a parallel region";
   }
   if (cct == ClauseCancellationConstructType::Loop) {
-    auto loopOp = dyn_cast<LoopNestOp>(parentOp);
-    auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
-        loopOp ? loopOp->getParentOp() : nullptr);
+    auto wsloopOp = thisOp->getParentOfType<WsloopOp>();
 
     if (!wsloopOp) {
       return emitOpError()
@@ -3134,12 +3127,12 @@ LogicalResult CancelOp::verify() {
     }
 
   } else if (cct == ClauseCancellationConstructType::Sections) {
-    if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
+    auto sectionsOp = thisOp->getParentOfType<SectionsOp>();
+    if (!sectionsOp) {
       return emitOpError() << "cancel sections must appear "
                            << "inside a sections region";
     }
-    if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
-        cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
+    if (sectionsOp.getNowait()) {
       return emitError() << "A sections construct that is canceled "
                          << "must not have a nowait clause";
     }
@@ -3159,25 +3152,20 @@ void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
 
 LogicalResult CancellationPointOp::verify() {
   ClauseCancellationConstructType cct = getCancelDirective();
-  Operation *parentOp = (*this)->getParentOp();
-
-  if (!parentOp) {
-    return emitOpError() << "must be used within a region supporting "
-                            "cancellation point directive";
-  }
+  Operation *thisOp = (*this).getOperation();
 
   if ((cct == ClauseCancellationConstructType::Parallel) &&
-      !(isa<ParallelOp>(parentOp))) {
+      !thisOp->getParentOfType<ParallelOp>()) {
     return emitOpError() << "cancellation point parallel must appear "
                          << "inside a parallel region";
   }
   if ((cct == ClauseCancellationConstructType::Loop) &&
-      (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
+      !thisOp->getParentOfType<WsloopOp>()) {
     return emitOpError() << "cancellation point loop must appear "
                          << "inside a worksharing-loop region";
   }
   if ((cct == ClauseCancellationConstructType::Sections) &&
-      !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
+      !thisOp->getParentOfType<SectionsOp>()) {
     return emitOpError() << "cancellation point sections must appear "
                          << "inside a sections region";
   }
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index a5cf789402726..378a841ae62df 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2201,6 +2201,48 @@ func.func @omp_cancel_sections() -> () {
   return
 }
 
+func.func @omp_cancel_parallel_nested(%if_cond : i1) -> () {
+  omp.parallel {
+    scf.if %if_cond {
+      // CHECK: omp.cancel cancellation_construct_type(parallel)
+      omp.cancel cancellation_construct_type(parallel)
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+func.func @omp_cancel_wsloop_nested(%lb : index, %ub : index, %step : index,
+                                    %if_cond : i1) {
+  omp.wsloop {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      scf.if %if_cond {
+        // CHECK: omp.cancel cancellation_construct_type(loop)
+        omp.cancel cancellation_construct_type(loop)
+      }
+      // CHECK: omp.yield
+      omp.yield
+    }
+  }
+  return
+}
+
+func.func @omp_cancel_sections_nested(%if_cond : i1) -> () {
+  omp.sections {
+    omp.section {
+      scf.if %if_cond {
+        // CHECK: omp.cancel cancellation_construct_type(sections)
+        omp.cancel cancellation_construct_type(sections)
+      }
+      omp.terminator
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
 func.func @omp_cancellationpoint_parallel() -> () {
   omp.parallel {
     // CHECK: omp.cancellation_point cancellation_construct_type(parallel)
@@ -2241,6 +2283,46 @@ func.func @omp_cancellationpoint_sections() -> () {
   return
 }
 
+func.func @omp_cancellationpoint_parallel_nested(%if_cond : i1) -> () {
+  omp.parallel {
+    scf.if %if_cond {
+      // CHECK: omp.cancellation_point cancellation_construct_type(parallel)
+      omp.cancellation_point cancellation_construct_type(parallel)
+    }
+    omp.terminator
+  }
+  return
+}
+
+func.func @omp_cancellationpoint_wsloop_nested(%lb : index, %ub : index, %step : index, %if_cond : i1) {
+  omp.wsloop {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      scf.if %if_cond {
+        // CHECK: omp.cancellation_point cancellation_construct_type(loop)
+        omp.cancellation_point cancellation_construct_type(loop)
+      }
+      // CHECK: omp.yield
+      omp.yield
+    }
+  }
+  return
+}
+
+func.func @omp_cancellationpoint_sections_nested(%if_cond : i1) -> () {
+  omp.sections {
+    omp.section {
+      scf.if %if_cond {
+        // CHECK: omp.cancellation_point cancellation_construct_type(sections)
+        omp.cancellation_point cancellation_construct_type(sections)
+      }
+      omp.terminator
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
 // CHECK-LABEL: @omp_taskgroup_no_tasks
 func.func @omp_taskgroup_no_tasks() -> () {
 



More information about the Mlir-commits mailing list