[Mlir-commits] [mlir] [mlir][OpenMP] Convert omp.cancel parallel to LLVMIR (PR #137192)

Tom Eccles llvmlistbot at llvm.org
Thu Apr 24 08:20:37 PDT 2025


https://github.com/tblah created https://github.com/llvm/llvm-project/pull/137192

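Only `omp.cancel cancellation_construct_type(parallel)` is translated by this patch; support for cancelling the other constructs (loop, sections, taskgroup) will follow in subsequent PRs.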

>From ba6d59cdb2bf906a60b7e13448af730bd2019140 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Tue, 8 Apr 2025 17:21:15 +0000
Subject: [PATCH] [mlir][OpenMP] Convert omp.cancel parallel to LLVMIR

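Only `omp.cancel cancellation_construct_type(parallel)` is translated by this patch; support for cancelling the other constructs (loop, sections, taskgroup) will follow in subsequent PRs.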
---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 69 +++++++++++++++-
 mlir/test/Target/LLVMIR/openmp-cancel.mlir    | 82 +++++++++++++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      | 48 +++++++++--
 3 files changed, 191 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-cancel.mlir

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 52aa1fbfab2c1..6185a433a8199 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -158,6 +158,12 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getBare())
       result = todo("ompx_bare");
   };
+  auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
+    // Only cancellation of an enclosing omp.parallel is translated so far.
+    if (op.getCancelDirective() !=
+        omp::ClauseCancellationConstructType::Parallel)
+      result = todo("cancel directive");
+  };
   auto checkDepend = [&todo](auto op, LogicalResult &result) {
     if (!op.getDependVars().empty() || op.getDependKinds())
       result = todo("depend");
@@ -248,6 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
 
   LogicalResult result = success();
   llvm::TypeSwitch<Operation &>(op)
+      .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
       .Case([&](omp::DistributeOp op) {
         checkAllocate(op, result);
         checkDistSchedule(op, result);
@@ -1580,6 +1587,21 @@ cleanupPrivateVars(llvm::IRBuilderBase &builder,
   return success();
 }
 
+/// Returns true if the construct contains omp.cancel or omp.cancellation_point
+static bool constructIsCancellable(Operation *op) {
+  // omp.cancel and omp.cancellation_point must be "closely nested" so they will
+  // be visible and not inside of function calls. This is enforced by the
+  // verifier.
+  return op
+      ->walk([](Operation *child) {
+        if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      })
+      .wasInterrupted();
+}
+}
+
 static LogicalResult
 convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) {
@@ -2524,8 +2546,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
   if (auto bind = opInst.getProcBindKind())
     pbKind = getProcBindKind(*bind);
-  // TODO: Is the Parallel construct cancellable?
-  bool isCancellable = false;
+  bool isCancellable = constructIsCancellable(opInst);
 
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
@@ -2991,6 +3012,47 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
   return success();
 }
 
+/// Maps the construct named by omp.cancel's cancellation_construct_type onto
+/// the llvm::omp::Directive passed to OpenMPIRBuilder::createCancel.
+static llvm::omp::Directive convertCancellationConstructType(
+    omp::ClauseCancellationConstructType directive) {
+  switch (directive) {
+  case omp::ClauseCancellationConstructType::Loop:
+    return llvm::omp::Directive::OMPD_for;
+  case omp::ClauseCancellationConstructType::Parallel:
+    return llvm::omp::Directive::OMPD_parallel;
+  case omp::ClauseCancellationConstructType::Sections:
+    return llvm::omp::Directive::OMPD_sections;
+  case omp::ClauseCancellationConstructType::Taskgroup:
+    return llvm::omp::Directive::OMPD_taskgroup;
+  }
+  // No default label (keeps -Wswitch useful); this silences GCC -Wreturn-type.
+  llvm_unreachable("Unhandled cancellation construct type");
+}
+
+/// Converts an omp.cancel operation into LLVM IR using OpenMPIRBuilder.
+static LogicalResult
+convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
+                 LLVM::ModuleTranslation &moduleTranslation) {
+  // Where (insertion point and debug location) the cancel gets emitted.
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+
+  // Bail out on cancellation construct types that are not supported yet.
+  if (failed(checkImplementationStatus(*op)))
+    return failure();
+
+  // Without an if clause the condition is nullptr (cancel unconditionally).
+  Value ifVar = op.getIfExpr();
+  llvm::Value *ifCond = ifVar ? moduleTranslation.lookupValue(ifVar) : nullptr;
+
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+      ompBuilder->createCancel(
+          ompLoc, ifCond,
+          convertCancellationConstructType(op.getCancelDirective()));
+  if (failed(handleError(afterIP, *op)))
+    return failure();
+
+  builder.restoreIP(afterIP.get());
+  return success();
+}
+
 /// Converts an OpenMP Threadprivate operation into LLVM IR using
 /// OpenMPIRBuilder.
 static LogicalResult
@@ -5421,6 +5483,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
           .Case([&](omp::AtomicCaptureOp op) {
             return convertOmpAtomicCapture(op, builder, moduleTranslation);
           })
+          .Case([&](omp::CancelOp op) {
+            return convertOmpCancel(op, builder, moduleTranslation);
+          })
           .Case([&](omp::SectionsOp) {
             return convertOmpSections(*op, builder, moduleTranslation);
           })
diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
new file mode 100644
index 0000000000000..1f67d6ceb34af
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// Unconditional cancellation of the enclosing omp.parallel. The expected IR
+// below calls __kmpc_cancel and, in the .cncl block, __kmpc_cancel_barrier.
+llvm.func @cancel_parallel() {
+  omp.parallel {
+    omp.cancel cancellation_construct_type(parallel)
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK-LABEL: define internal void @cancel_parallel..omp_par
+// CHECK:       omp.par.entry:
+// CHECK:         %[[VAL_5:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_6:.*]] = load i32, ptr %[[VAL_7:.*]], align 4
+// CHECK:         store i32 %[[VAL_6]], ptr %[[VAL_5]], align 4
+// CHECK:         %[[VAL_8:.*]] = load i32, ptr %[[VAL_5]], align 4
+// CHECK:         br label %[[VAL_9:.*]]
+// CHECK:       omp.region.after_alloca:                          ; preds = %[[VAL_10:.*]]
+// CHECK:         br label %[[VAL_11:.*]]
+// CHECK:       omp.par.region:                                   ; preds = %[[VAL_9]]
+// CHECK:         br label %[[VAL_12:.*]]
+// CHECK:       omp.par.region1:                                  ; preds = %[[VAL_11]]
+// CHECK:         %[[VAL_13:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         %[[VAL_14:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_13]], i32 1)
+// CHECK:         %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0
+// CHECK:         br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]]
+// CHECK:       omp.par.region1.cncl:                             ; preds = %[[VAL_12]]
+// CHECK:         %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]])
+// CHECK:         br label %[[VAL_20:.*]]
+// CHECK:       omp.par.region1.split:                            ; preds = %[[VAL_12]]
+// CHECK:         br label %[[VAL_21:.*]]
+// CHECK:       omp.region.cont:                                  ; preds = %[[VAL_16]]
+// CHECK:         br label %[[VAL_22:.*]]
+// CHECK:       omp.par.pre_finalize:                             ; preds = %[[VAL_21]]
+// CHECK:         br label %[[VAL_20]]
+// CHECK:       omp.par.exit.exitStub:                            ; preds = %[[VAL_22]], %[[VAL_17]]
+// CHECK:         ret void
+
+// Cancellation of the enclosing omp.parallel guarded by an if clause: the
+// outlined body branches on the i1 value of %arg0 before calling __kmpc_cancel.
+llvm.func @cancel_parallel_if(%arg0 : i1) {
+  omp.parallel {
+    omp.cancel cancellation_construct_type(parallel) if(%arg0)
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK-LABEL: define internal void @cancel_parallel_if..omp_par
+// CHECK:       omp.par.entry:
+// CHECK:         %[[VAL_9:.*]] = getelementptr { ptr }, ptr %[[VAL_10:.*]], i32 0, i32 0
+// CHECK:         %[[VAL_11:.*]] = load ptr, ptr %[[VAL_9]], align 8
+// CHECK:         %[[VAL_12:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_13:.*]] = load i32, ptr %[[VAL_14:.*]], align 4
+// CHECK:         store i32 %[[VAL_13]], ptr %[[VAL_12]], align 4
+// CHECK:         %[[VAL_15:.*]] = load i32, ptr %[[VAL_12]], align 4
+// CHECK:         %[[VAL_16:.*]] = load i1, ptr %[[VAL_11]], align 1
+// CHECK:         br label %[[VAL_17:.*]]
+// CHECK:       omp.region.after_alloca:                          ; preds = %[[VAL_18:.*]]
+// CHECK:         br label %[[VAL_19:.*]]
+// CHECK:       omp.par.region:                                   ; preds = %[[VAL_17]]
+// CHECK:         br label %[[VAL_20:.*]]
+// CHECK:       omp.par.region1:                                  ; preds = %[[VAL_19]]
+// CHECK:         br i1 %[[VAL_16]], label %[[VAL_21:.*]], label %[[VAL_22:.*]]
+// CHECK:       3:                                                ; preds = %[[VAL_20]]
+// CHECK:         br label %[[VAL_23:.*]]
+// CHECK:       4:                                                ; preds = %[[VAL_22]], %[[VAL_24:.*]]
+// CHECK:         br label %[[VAL_25:.*]]
+// CHECK:       omp.region.cont:                                  ; preds = %[[VAL_23]]
+// CHECK:         br label %[[VAL_26:.*]]
+// CHECK:       omp.par.pre_finalize:                             ; preds = %[[VAL_25]]
+// CHECK:         br label %[[VAL_27:.*]]
+// CHECK:       5:                                                ; preds = %[[VAL_20]]
+// CHECK:         %[[VAL_28:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         %[[VAL_29:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_28]], i32 1)
+// CHECK:         %[[VAL_30:.*]] = icmp eq i32 %[[VAL_29]], 0
+// CHECK:         br i1 %[[VAL_30]], label %[[VAL_24]], label %[[VAL_31:.*]]
+// CHECK:       .cncl:                                            ; preds = %[[VAL_21]]
+// CHECK:         %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]])
+// CHECK:         br label %[[VAL_27]]
+// CHECK:       .split:                                           ; preds = %[[VAL_21]]
+// CHECK:         br label %[[VAL_23]]
+// CHECK:       omp.par.exit.exitStub:                            ; preds = %[[VAL_31]], %[[VAL_26]]
+// CHECK:         ret void
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 7eafe396082e4..bf251ac2b7d0a 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -26,12 +26,48 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
 
 // -----
 
-llvm.func @cancel() {
-  // expected-error at below {{LLVM Translation failed for operation: omp.parallel}}
-  omp.parallel {
-    // expected-error at below {{not yet implemented: omp.cancel}}
-    // expected-error at below {{LLVM Translation failed for operation: omp.cancel}}
-    omp.cancel cancellation_construct_type(parallel)
+// Cancelling a worksharing loop is not translated yet; omp.cancel reports a
+// not-yet-implemented error that propagates to the enclosing operations.
+llvm.func @cancel_wsloop(%lb : i32, %ub : i32, %step: i32) {
+  // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
+  omp.wsloop {
+    // expected-error@below {{LLVM Translation failed for operation: omp.loop_nest}}
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      // expected-error@below {{not yet implemented: Unhandled clause cancel directive in omp.cancel operation}}
+      // expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
+      omp.cancel cancellation_construct_type(loop)
+      omp.yield
+    }
+  }
+  llvm.return
+}
+
+// -----
+
+// Cancelling a sections construct is not translated yet; omp.cancel reports a
+// not-yet-implemented error that propagates to the enclosing omp.sections.
+llvm.func @cancel_sections() {
+  // expected-error@below {{LLVM Translation failed for operation: omp.sections}}
+  omp.sections {
+    omp.section {
+      // expected-error@below {{not yet implemented: Unhandled clause cancel directive in omp.cancel operation}}
+      // expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
+      omp.cancel cancellation_construct_type(sections)
+      omp.terminator
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
+llvm.func @cancel_taskgroup() {
+  // expected-error at below {{LLVM Translation failed for operation: omp.taskgroup}}
+  omp.taskgroup {
+    // expected-error at below {{LLVM Translation failed for operation: omp.task}}
+    omp.task {
+      // expected-error at below {{not yet implemented: Unhandled clause cancel directive in omp.cancel operation}}
+      // expected-error at below {{LLVM Translation failed for operation: omp.cancel}}
+      omp.cancel cancellation_construct_type(taskgroup)
+      omp.terminator
+    }
     omp.terminator
   }
   llvm.return



More information about the Mlir-commits mailing list