[Mlir-commits] [mlir] 7734138 - [mlir][OpenMP] allow cancellation to not be directly nested (#134084)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 14 02:39:06 PDT 2025


Author: Tom Eccles
Date: 2025-04-14T10:39:03+01:00
New Revision: 77341388a77b1442b3a54d745fc269dabb175f0c

URL: https://github.com/llvm/llvm-project/commit/77341388a77b1442b3a54d745fc269dabb175f0c
DIFF: https://github.com/llvm/llvm-project/commit/77341388a77b1442b3a54d745fc269dabb175f0c.diff

LOG: [mlir][OpenMP] allow cancellation to not be directly nested (#134084)

omp.cancel and omp.cancellationpoint contain an attribute describing the
type of parent construct which should be cancelled. e.g.
```
!$omp cancel do
```
Must be inside of a wsloop. Previously the verifier required the
immediate parent to be this operation. This is not quite right because
something like the following is valid:
```
!$omp parallel do
do i = 1, N
  if (cond) then
    !$omp cancel do
  endif
enddo
```

This patch relaxes the verifier to only require that some parent
operation matches (not necessarily the immediate parent).

Added: 
    

Modified: 
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index cf0b4bf6e95ed..dd701da507fc6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3162,24 +3162,32 @@ void CancelOp::build(OpBuilder &builder, OperationState &state,
   CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
 }
 
+static Operation *getParentInSameDialect(Operation *thisOp) {
+  Operation *parent = thisOp->getParentOp();
+  while (parent) {
+    if (parent->getDialect() == thisOp->getDialect())
+      return parent;
+    parent = parent->getParentOp();
+  }
+  return nullptr;
+}
+
 LogicalResult CancelOp::verify() {
   ClauseCancellationConstructType cct = getCancelDirective();
-  Operation *parentOp = (*this)->getParentOp();
-
-  if (!parentOp) {
-    return emitOpError() << "must be used within a region supporting "
-                            "cancel directive";
-  }
+  // Nearest enclosing OpenMP op; intervening ops (e.g. scf.if) are skipped.
+  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
+  if (!structuralParent)
+    return emitOpError() << "Orphaned cancel construct";
 
   if ((cct == ClauseCancellationConstructType::Parallel) &&
-      !isa<ParallelOp>(parentOp)) {
+      !mlir::isa<ParallelOp>(structuralParent)) {
     return emitOpError() << "cancel parallel must appear "
                          << "inside a parallel region";
   }
   if (cct == ClauseCancellationConstructType::Loop) {
-    auto loopOp = dyn_cast<LoopNestOp>(parentOp);
-    auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
-        loopOp ? loopOp->getParentOp() : nullptr);
+    // Structural parent is omp.loop_nest, directly nested in omp.wsloop.
+    auto wsloopOp =
+        llvm::dyn_cast_if_present<WsloopOp>(structuralParent->getParentOp());
 
     if (!wsloopOp) {
       return emitOpError()
@@ -3195,12 +3203,15 @@ LogicalResult CancelOp::verify() {
     }
 
   } else if (cct == ClauseCancellationConstructType::Sections) {
-    if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
+    // Structural parent is omp.section, directly nested in omp.sections.
+    // getParentOp() may be null if the structural parent is top-level.
+    auto sectionsOp =
+        llvm::dyn_cast_if_present<SectionsOp>(structuralParent->getParentOp());
+    if (!sectionsOp) {
       return emitOpError() << "cancel sections must appear "
                            << "inside a sections region";
     }
-    if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
-        cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
+    if (sectionsOp.getNowait()) {
       return emitError() << "A sections construct that is canceled "
                          << "must not have a nowait clause";
     }
@@ -3220,25 +3231,25 @@ void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
 
 LogicalResult CancellationPointOp::verify() {
   ClauseCancellationConstructType cct = getCancelDirective();
-  Operation *parentOp = (*this)->getParentOp();
-
-  if (!parentOp) {
-    return emitOpError() << "must be used within a region supporting "
-                            "cancellation point directive";
-  }
+  // Nearest enclosing OpenMP op; intervening ops (e.g. scf.if) are skipped.
+  Operation *structuralParent = getParentInSameDialect((*this).getOperation());
+  if (!structuralParent)
+    return emitOpError() << "Orphaned cancellation point";
 
   if ((cct == ClauseCancellationConstructType::Parallel) &&
-      !(isa<ParallelOp>(parentOp))) {
+      !mlir::isa<ParallelOp>(structuralParent)) {
     return emitOpError() << "cancellation point parallel must appear "
                          << "inside a parallel region";
   }
+  // Structural parent here will be an omp.loop_nest. Get the parent of that
+  // (if any) to find the wsloop.
   if ((cct == ClauseCancellationConstructType::Loop) &&
-      (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
+      !llvm::isa_and_present<WsloopOp>(structuralParent->getParentOp())) {
     return emitOpError() << "cancellation point loop must appear "
                          << "inside a worksharing-loop region";
   }
   if ((cct == ClauseCancellationConstructType::Sections) &&
-      !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
+      !mlir::isa<omp::SectionOp>(structuralParent)) {
     return emitOpError() << "cancellation point sections must appear "
                          << "inside a sections region";
   }

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e08adb08f7e99..41e17881afd30 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1710,6 +1710,14 @@ func.func @omp_task(%mem: memref<1xf32>) {
 
 // -----
 
+func.func @omp_cancel() {
+  // expected-error @below {{Orphaned cancel construct}}
+  omp.cancel cancellation_construct_type(parallel)
+  return
+}
+
+// -----
+
 func.func @omp_cancel() {
   omp.sections {
     // expected-error @below {{cancel parallel must appear inside a parallel region}}
@@ -1789,6 +1797,14 @@ func.func @omp_cancel5() -> () {
 
 // -----
 
+func.func @omp_cancellationpoint() {
+  // expected-error @below {{Orphaned cancellation point}}
+  omp.cancellation_point cancellation_construct_type(parallel)
+  return
+}
+
+// -----
+
 func.func @omp_cancellationpoint() {
   omp.sections {
     // expected-error @below {{cancellation point parallel must appear inside a parallel region}}

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 6bc2500471997..d5e2bfa5d3949 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2201,6 +2201,48 @@ func.func @omp_cancel_sections() -> () {
   return
 }
 
+func.func @omp_cancel_parallel_nested(%if_cond : i1) -> () {
+  omp.parallel {
+    scf.if %if_cond {
+      // CHECK: omp.cancel cancellation_construct_type(parallel)
+      omp.cancel cancellation_construct_type(parallel)
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+func.func @omp_cancel_wsloop_nested(%lb : index, %ub : index, %step : index,
+                                    %if_cond : i1) {
+  omp.wsloop {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      scf.if %if_cond {
+        // CHECK: omp.cancel cancellation_construct_type(loop)
+        omp.cancel cancellation_construct_type(loop)
+      }
+      // CHECK: omp.yield
+      omp.yield
+    }
+  }
+  return
+}
+
+func.func @omp_cancel_sections_nested(%if_cond : i1) -> () {
+  omp.sections {
+    omp.section {
+      scf.if %if_cond {
+        // CHECK: omp.cancel cancellation_construct_type(sections)
+        omp.cancel cancellation_construct_type(sections)
+      }
+      omp.terminator
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
 func.func @omp_cancellationpoint_parallel() -> () {
   omp.parallel {
     // CHECK: omp.cancellation_point cancellation_construct_type(parallel)
@@ -2241,6 +2283,46 @@ func.func @omp_cancellationpoint_sections() -> () {
   return
 }
 
+func.func @omp_cancellationpoint_parallel_nested(%if_cond : i1) -> () {
+  omp.parallel {
+    scf.if %if_cond {
+      // CHECK: omp.cancellation_point cancellation_construct_type(parallel)
+      omp.cancellation_point cancellation_construct_type(parallel)
+    }
+    omp.terminator
+  }
+  return
+}
+
+func.func @omp_cancellationpoint_wsloop_nested(%lb : index, %ub : index, %step : index, %if_cond : i1) {
+  omp.wsloop {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      scf.if %if_cond {
+        // CHECK: omp.cancellation_point cancellation_construct_type(loop)
+        omp.cancellation_point cancellation_construct_type(loop)
+      }
+      // CHECK: omp.yield
+      omp.yield
+    }
+  }
+  return
+}
+
+func.func @omp_cancellationpoint_sections_nested(%if_cond : i1) -> () {
+  omp.sections {
+    omp.section {
+      scf.if %if_cond {
+        // CHECK: omp.cancellation_point cancellation_construct_type(sections)
+        omp.cancellation_point cancellation_construct_type(sections)
+      }
+      omp.terminator
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
 // CHECK-LABEL: @omp_taskgroup_no_tasks
 func.func @omp_taskgroup_no_tasks() -> () {
 


        


More information about the Mlir-commits mailing list