[Mlir-commits] [mlir] [mlir][OpenMP] allow cancellation to not be directly nested (PR #134084)

Tom Eccles llvmlistbot at llvm.org
Thu Apr 3 03:07:40 PDT 2025


https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/134084

>From 66d0f84a3db19c46acdafb96997bba602b9b354f Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 2 Apr 2025 13:21:41 +0000
Subject: [PATCH 1/2] [mlir][OpenMP] allow cancellation to not be directly
 nested

omp.cancel and omp.cancellation_point contain an attribute describing the
type of parent construct which should be cancelled. e.g.
```
!$omp cancel do
```
This must be inside of a wsloop. Previously the verifier required the
immediate parent to be this operation. This is too strict, because
something like the following is valid:
```
!$omp parallel do
do i = 1, N
  if (cond) then
    !$omp cancel do
  endif
enddo
```

This patch relaxes the verifier to only require that some parent
operation matches (not necessarily the immediate parent).
---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 32 +++-----
 mlir/test/Dialect/OpenMP/ops.mlir            | 82 ++++++++++++++++++++
 2 files changed, 92 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 882bc4071482f..e45d6d9fb3831 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3103,22 +3103,15 @@ void CancelOp::build(OpBuilder &builder, OperationState &state,
 
 LogicalResult CancelOp::verify() {
   ClauseCancellationConstructType cct = getCancelDirective();
-  Operation *parentOp = (*this)->getParentOp();
-
-  if (!parentOp) {
-    return emitOpError() << "must be used within a region supporting "
-                            "cancel directive";
-  }
+  Operation *thisOp = (*this).getOperation();
 
   if ((cct == ClauseCancellationConstructType::Parallel) &&
-      !isa<ParallelOp>(parentOp)) {
+      !thisOp->getParentOfType<ParallelOp>()) {
     return emitOpError() << "cancel parallel must appear "
                          << "inside a parallel region";
   }
   if (cct == ClauseCancellationConstructType::Loop) {
-    auto loopOp = dyn_cast<LoopNestOp>(parentOp);
-    auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
-        loopOp ? loopOp->getParentOp() : nullptr);
+    auto wsloopOp = thisOp->getParentOfType<WsloopOp>();
 
     if (!wsloopOp) {
       return emitOpError()
@@ -3134,12 +3127,12 @@ LogicalResult CancelOp::verify() {
     }
 
   } else if (cct == ClauseCancellationConstructType::Sections) {
-    if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
+    auto sectionsOp = thisOp->getParentOfType<SectionsOp>();
+    if (!sectionsOp) {
       return emitOpError() << "cancel sections must appear "
                            << "inside a sections region";
     }
-    if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
-        cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
+    if (sectionsOp.getNowait()) {
       return emitError() << "A sections construct that is canceled "
                          << "must not have a nowait clause";
     }
@@ -3159,25 +3152,20 @@ void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
 
 LogicalResult CancellationPointOp::verify() {
   ClauseCancellationConstructType cct = getCancelDirective();
-  Operation *parentOp = (*this)->getParentOp();
-
-  if (!parentOp) {
-    return emitOpError() << "must be used within a region supporting "
-                            "cancellation point directive";
-  }
+  Operation *thisOp = (*this).getOperation();
 
   if ((cct == ClauseCancellationConstructType::Parallel) &&
-      !(isa<ParallelOp>(parentOp))) {
+      !thisOp->getParentOfType<ParallelOp>()) {
     return emitOpError() << "cancellation point parallel must appear "
                          << "inside a parallel region";
   }
   if ((cct == ClauseCancellationConstructType::Loop) &&
-      (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
+      !thisOp->getParentOfType<WsloopOp>()) {
     return emitOpError() << "cancellation point loop must appear "
                          << "inside a worksharing-loop region";
   }
   if ((cct == ClauseCancellationConstructType::Sections) &&
-      !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
+      !thisOp->getParentOfType<SectionsOp>()) {
     return emitOpError() << "cancellation point sections must appear "
                          << "inside a sections region";
   }
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index a5cf789402726..378a841ae62df 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2201,6 +2201,48 @@ func.func @omp_cancel_sections() -> () {
   return
 }
 
+func.func @omp_cancel_parallel_nested(%if_cond : i1) -> () {
+  omp.parallel {
+    scf.if %if_cond {
+      // CHECK: omp.cancel cancellation_construct_type(parallel)
+      omp.cancel cancellation_construct_type(parallel)
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+func.func @omp_cancel_wsloop_nested(%lb : index, %ub : index, %step : index,
+                                    %if_cond : i1) {
+  omp.wsloop {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      scf.if %if_cond {
+        // CHECK: omp.cancel cancellation_construct_type(loop)
+        omp.cancel cancellation_construct_type(loop)
+      }
+      // CHECK: omp.yield
+      omp.yield
+    }
+  }
+  return
+}
+
+func.func @omp_cancel_sections_nested(%if_cond : i1) -> () {
+  omp.sections {
+    omp.section {
+      scf.if %if_cond {
+        // CHECK: omp.cancel cancellation_construct_type(sections)
+        omp.cancel cancellation_construct_type(sections)
+      }
+      omp.terminator
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
 func.func @omp_cancellationpoint_parallel() -> () {
   omp.parallel {
     // CHECK: omp.cancellation_point cancellation_construct_type(parallel)
@@ -2241,6 +2283,46 @@ func.func @omp_cancellationpoint_sections() -> () {
   return
 }
 
+func.func @omp_cancellationpoint_parallel_nested(%if_cond : i1) -> () {
+  omp.parallel {
+    scf.if %if_cond {
+      // CHECK: omp.cancellation_point cancellation_construct_type(parallel)
+      omp.cancellation_point cancellation_construct_type(parallel)
+    }
+    omp.terminator
+  }
+  return
+}
+
+func.func @omp_cancellationpoint_wsloop_nested(%lb : index, %ub : index, %step : index, %if_cond : i1) {
+  omp.wsloop {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      scf.if %if_cond {
+        // CHECK: omp.cancellation_point cancellation_construct_type(loop)
+        omp.cancellation_point cancellation_construct_type(loop)
+      }
+      // CHECK: omp.yield
+      omp.yield
+    }
+  }
+  return
+}
+
+func.func @omp_cancellationpoint_sections_nested(%if_cond : i1) -> () {
+  omp.sections {
+    omp.section {
+      scf.if %if_cond {
+        // CHECK: omp.cancellation_point cancellation_construct_type(sections)
+        omp.cancellation_point cancellation_construct_type(sections)
+      }
+      omp.terminator
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
 // CHECK-LABEL: @omp_taskgroup_no_tasks
 func.func @omp_taskgroup_no_tasks() -> () {
 

>From da8639e63dd0b73972b2e4b6537b3c1745e7289b Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Thu, 3 Apr 2025 10:06:19 +0000
Subject: [PATCH 2/2] Remove cancel and cancellation point verifiers entirely

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   4 -
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  63 ----------
 mlir/test/Dialect/OpenMP/invalid.mlir         | 115 ------------------
 3 files changed, 182 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 65095932be627..046c5da91cda3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1655,8 +1655,6 @@ def CancelOp : OpenMP_Op<"cancel", clauses = [
   let builders = [
     OpBuilder<(ins CArg<"const CancelOperands &">:$clauses)>
   ];
-
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1675,8 +1673,6 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
   let builders = [
     OpBuilder<(ins CArg<"const CancellationPointOperands &">:$clauses)>
   ];
-
-  let hasVerifier = 1;
 }
 
 def ScanOp : OpenMP_Op<"scan", [
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e45d6d9fb3831..6a6994b3e484c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3101,46 +3101,6 @@ void CancelOp::build(OpBuilder &builder, OperationState &state,
   CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
 }
 
-LogicalResult CancelOp::verify() {
-  ClauseCancellationConstructType cct = getCancelDirective();
-  Operation *thisOp = (*this).getOperation();
-
-  if ((cct == ClauseCancellationConstructType::Parallel) &&
-      !thisOp->getParentOfType<ParallelOp>()) {
-    return emitOpError() << "cancel parallel must appear "
-                         << "inside a parallel region";
-  }
-  if (cct == ClauseCancellationConstructType::Loop) {
-    auto wsloopOp = thisOp->getParentOfType<WsloopOp>();
-
-    if (!wsloopOp) {
-      return emitOpError()
-             << "cancel loop must appear inside a worksharing-loop region";
-    }
-    if (wsloopOp.getNowaitAttr()) {
-      return emitError() << "A worksharing construct that is canceled "
-                         << "must not have a nowait clause";
-    }
-    if (wsloopOp.getOrderedAttr()) {
-      return emitError() << "A worksharing construct that is canceled "
-                         << "must not have an ordered clause";
-    }
-
-  } else if (cct == ClauseCancellationConstructType::Sections) {
-    auto sectionsOp = thisOp->getParentOfType<SectionsOp>();
-    if (!sectionsOp) {
-      return emitOpError() << "cancel sections must appear "
-                           << "inside a sections region";
-    }
-    if (sectionsOp.getNowait()) {
-      return emitError() << "A sections construct that is canceled "
-                         << "must not have a nowait clause";
-    }
-  }
-  // TODO : Add more when we support taskgroup.
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // CancellationPointOp
 //===----------------------------------------------------------------------===//
@@ -3150,29 +3110,6 @@ void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
   CancellationPointOp::build(builder, state, clauses.cancelDirective);
 }
 
-LogicalResult CancellationPointOp::verify() {
-  ClauseCancellationConstructType cct = getCancelDirective();
-  Operation *thisOp = (*this).getOperation();
-
-  if ((cct == ClauseCancellationConstructType::Parallel) &&
-      !thisOp->getParentOfType<ParallelOp>()) {
-    return emitOpError() << "cancellation point parallel must appear "
-                         << "inside a parallel region";
-  }
-  if ((cct == ClauseCancellationConstructType::Loop) &&
-      !thisOp->getParentOfType<WsloopOp>()) {
-    return emitOpError() << "cancellation point loop must appear "
-                         << "inside a worksharing-loop region";
-  }
-  if ((cct == ClauseCancellationConstructType::Sections) &&
-      !thisOp->getParentOfType<SectionsOp>()) {
-    return emitOpError() << "cancellation point sections must appear "
-                         << "inside a sections region";
-  }
-  // TODO : Add more when we support taskgroup.
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // MapBoundsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 403128bb2300e..2aafe082624f1 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1710,121 +1710,6 @@ func.func @omp_task(%mem: memref<1xf32>) {
 
 // -----
 
-func.func @omp_cancel() {
-  omp.sections {
-    // expected-error @below {{cancel parallel must appear inside a parallel region}}
-    omp.cancel cancellation_construct_type(parallel)
-    // CHECK: omp.terminator
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_cancel1() {
-  omp.parallel {
-    // expected-error @below {{cancel sections must appear inside a sections region}}
-    omp.cancel cancellation_construct_type(sections)
-    // CHECK: omp.terminator
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_cancel2() {
-  omp.sections {
-    // expected-error @below {{cancel loop must appear inside a worksharing-loop region}}
-    omp.cancel cancellation_construct_type(loop)
-    // CHECK: omp.terminator
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_cancel3(%arg1 : i32, %arg2 : i32, %arg3 : i32) -> () {
-  omp.wsloop nowait {
-    omp.loop_nest (%0) : i32 = (%arg1) to (%arg2) step (%arg3) {
-      // expected-error @below {{A worksharing construct that is canceled must not have a nowait clause}}
-      omp.cancel cancellation_construct_type(loop)
-      // CHECK: omp.yield
-      omp.yield
-    }
-  }
-  return
-}
-
-// -----
-
-func.func @omp_cancel4(%arg1 : i32, %arg2 : i32, %arg3 : i32) -> () {
-  omp.wsloop ordered(1) {
-    omp.loop_nest (%0) : i32 = (%arg1) to (%arg2) step (%arg3) {
-      // expected-error @below {{A worksharing construct that is canceled must not have an ordered clause}}
-      omp.cancel cancellation_construct_type(loop)
-      // CHECK: omp.yield
-      omp.yield
-    }
-  }
-  return
-}
-
-// -----
-
-func.func @omp_cancel5() -> () {
-  omp.sections nowait {
-    omp.section {
-      // expected-error @below {{A sections construct that is canceled must not have a nowait clause}}
-      omp.cancel cancellation_construct_type(sections)
-      omp.terminator
-    }
-    // CHECK: omp.terminator
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_cancellationpoint() {
-  omp.sections {
-    // expected-error @below {{cancellation point parallel must appear inside a parallel region}}
-    omp.cancellation_point cancellation_construct_type(parallel)
-    // CHECK: omp.terminator
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_cancellationpoint1() {
-  omp.parallel {
-    // expected-error @below {{cancellation point sections must appear inside a sections region}}
-    omp.cancellation_point cancellation_construct_type(sections)
-    // CHECK: omp.terminator
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_cancellationpoint2() {
-  omp.sections {
-    // expected-error @below {{cancellation point loop must appear inside a worksharing-loop region}}
-    omp.cancellation_point cancellation_construct_type(loop)
-    // CHECK: omp.terminator
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
 omp.declare_reduction @add_f32 : f32
 init {
  ^bb0(%arg: f32):



More information about the Mlir-commits mailing list