[Mlir-commits] [mlir] [mlir][OpenMP] convert wsloop cancellation to LLVMIR (PR #137194)

Tom Eccles llvmlistbot at llvm.org
Tue Apr 29 09:24:37 PDT 2025


https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/137194

>From a51f9e45ed09d382cdc4a1526929330c2bc463cc Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Tue, 15 Apr 2025 15:05:50 +0000
Subject: [PATCH] [mlir][OpenMP] convert wsloop cancellation to LLVMIR

Taskloop support will follow in a later patch.
---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 40 ++++++++-
 mlir/test/Target/LLVMIR/openmp-cancel.mlir    | 87 +++++++++++++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      | 16 ----
 3 files changed, 125 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 16c93f902d9ec..48d6dee6f5905 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -161,8 +161,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
   auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
     omp::ClauseCancellationConstructType cancelledDirective =
         op.getCancelDirective();
-    if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel &&
-        cancelledDirective != omp::ClauseCancellationConstructType::Sections)
+    if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
       result = todo("cancel directive construct type not yet supported");
   };
   auto checkDepend = [&todo](auto op, LogicalResult &result) {
@@ -2358,6 +2357,30 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
           ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
           : llvm::omp::WorksharingLoopType::ForStaticLoop;
 
+  SmallVector<llvm::BranchInst *> cancelTerminators;
+  // This callback is invoked only if there is cancellation inside of the wsloop
+  // body.
+  auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
+    llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
+    llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
+
+    // ip is currently in the block branched to if cancellation occurred.
+    // We need to create a branch to terminate that block.
+    llvmBuilder.restoreIP(ip);
+
+    // We must still clean up the wsloop after cancelling it, so we need to
+    // branch to the block that finalizes the wsloop.
+    // That block has not been created yet so use this block as a dummy for now
+    // and fix this after creating the wsloop.
+    cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
+    return llvm::Error::success();
+  };
+  // We have to add the cleanup to the OpenMPIRBuilder before the body gets
+  // created in case the body contains omp.cancel (which will then expect to be
+  // able to find this cleanup callback).
+  ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
+                                  constructIsCancellable(wsloopOp)});
+
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
   llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
       wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
@@ -2379,6 +2402,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
+  ompBuilder->popFinalizationCB();
+  if (!cancelTerminators.empty()) {
+    // If we cancelled the loop, we should branch to the finalization block of
+    // the wsloop (which is always immediately before the loop continuation
+    // block). Now the finalization has been created, we can fix the branch.
+    llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
+    for (llvm::BranchInst *cancelBranch : cancelTerminators) {
+      assert(cancelBranch->getNumSuccessors() == 1 &&
+             "cancel branch should have one target");
+      cancelBranch->setSuccessor(0, wsloopFini);
+    }
+  }
+
   // Process the reductions if required.
   if (failed(createReductionsAndCleanup(
           wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
index fca16b936fc85..3c195a98d1000 100644
--- a/mlir/test/Target/LLVMIR/openmp-cancel.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
@@ -156,3 +156,90 @@ llvm.func @cancel_sections_if(%cond : i1) {
 // CHECK:         ret void
 // CHECK:       .cncl:                                            ; preds = %[[VAL_27]]
 // CHECK:         br label %[[VAL_19]]
+
+llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
+  omp.wsloop {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      omp.cancel cancellation_construct_type(loop) if(%cond)
+      omp.yield
+    }
+  }
+  llvm.return
+}
+// CHECK-LABEL: define void @cancel_wsloop_if
+// CHECK:         %[[VAL_0:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_1:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_2:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_3:.*]] = alloca i32, align 4
+// CHECK:         br label %[[VAL_4:.*]]
+// CHECK:       omp.region.after_alloca:                          ; preds = %[[VAL_5:.*]]
+// CHECK:         br label %[[VAL_6:.*]]
+// CHECK:       entry:                                            ; preds = %[[VAL_4]]
+// CHECK:         br label %[[VAL_7:.*]]
+// CHECK:       omp.wsloop.region:                                ; preds = %[[VAL_6]]
+// CHECK:         %[[VAL_8:.*]] = icmp slt i32 %[[VAL_9:.*]], 0
+// CHECK:         %[[VAL_10:.*]] = sub i32 0, %[[VAL_9]]
+// CHECK:         %[[VAL_11:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_10]], i32 %[[VAL_9]]
+// CHECK:         %[[VAL_12:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_13:.*]], i32 %[[VAL_14:.*]]
+// CHECK:         %[[VAL_15:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_14]], i32 %[[VAL_13]]
+// CHECK:         %[[VAL_16:.*]] = sub nsw i32 %[[VAL_15]], %[[VAL_12]]
+// CHECK:         %[[VAL_17:.*]] = icmp sle i32 %[[VAL_15]], %[[VAL_12]]
+// CHECK:         %[[VAL_18:.*]] = sub i32 %[[VAL_16]], 1
+// CHECK:         %[[VAL_19:.*]] = udiv i32 %[[VAL_18]], %[[VAL_11]]
+// CHECK:         %[[VAL_20:.*]] = add i32 %[[VAL_19]], 1
+// CHECK:         %[[VAL_21:.*]] = icmp ule i32 %[[VAL_16]], %[[VAL_11]]
+// CHECK:         %[[VAL_22:.*]] = select i1 %[[VAL_21]], i32 1, i32 %[[VAL_20]]
+// CHECK:         %[[VAL_23:.*]] = select i1 %[[VAL_17]], i32 0, i32 %[[VAL_22]]
+// CHECK:         br label %[[VAL_24:.*]]
+// CHECK:       omp_loop.preheader:                               ; preds = %[[VAL_7]]
+// CHECK:         store i32 0, ptr %[[VAL_1]], align 4
+// CHECK:         %[[VAL_25:.*]] = sub i32 %[[VAL_23]], 1
+// CHECK:         store i32 %[[VAL_25]], ptr %[[VAL_2]], align 4
+// CHECK:         store i32 1, ptr %[[VAL_3]], align 4
+// CHECK:         %[[VAL_26:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_26]], i32 34, ptr %[[VAL_0]], ptr %[[VAL_1]], ptr %[[VAL_2]], ptr %[[VAL_3]], i32 1, i32 0)
+// CHECK:         %[[VAL_27:.*]] = load i32, ptr %[[VAL_1]], align 4
+// CHECK:         %[[VAL_28:.*]] = load i32, ptr %[[VAL_2]], align 4
+// CHECK:         %[[VAL_29:.*]] = sub i32 %[[VAL_28]], %[[VAL_27]]
+// CHECK:         %[[VAL_30:.*]] = add i32 %[[VAL_29]], 1
+// CHECK:         br label %[[VAL_31:.*]]
+// CHECK:       omp_loop.header:                                  ; preds = %[[VAL_32:.*]], %[[VAL_24]]
+// CHECK:         %[[VAL_33:.*]] = phi i32 [ 0, %[[VAL_24]] ], [ %[[VAL_34:.*]], %[[VAL_32]] ]
+// CHECK:         br label %[[VAL_35:.*]]
+// CHECK:       omp_loop.cond:                                    ; preds = %[[VAL_31]]
+// CHECK:         %[[VAL_36:.*]] = icmp ult i32 %[[VAL_33]], %[[VAL_30]]
+// CHECK:         br i1 %[[VAL_36]], label %[[VAL_37:.*]], label %[[VAL_38:.*]]
+// CHECK:       omp_loop.body:                                    ; preds = %[[VAL_35]]
+// CHECK:         %[[VAL_39:.*]] = add i32 %[[VAL_33]], %[[VAL_27]]
+// CHECK:         %[[VAL_40:.*]] = mul i32 %[[VAL_39]], %[[VAL_9]]
+// CHECK:         %[[VAL_41:.*]] = add i32 %[[VAL_40]], %[[VAL_14]]
+// CHECK:         br label %[[VAL_42:.*]]
+// CHECK:       omp.loop_nest.region:                             ; preds = %[[VAL_37]]
+// CHECK:         br i1 %[[VAL_43:.*]], label %[[VAL_44:.*]], label %[[VAL_45:.*]]
+// CHECK:       25:                                               ; preds = %[[VAL_42]]
+// CHECK:         %[[VAL_46:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         %[[VAL_47:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_46]], i32 2)
+// CHECK:         %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0
+// CHECK:         br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]]
+// CHECK:       .split:                                           ; preds = %[[VAL_44]]
+// CHECK:         br label %[[VAL_51:.*]]
+// CHECK:       28:                                               ; preds = %[[VAL_42]]
+// CHECK:         br label %[[VAL_51]]
+// CHECK:       29:                                               ; preds = %[[VAL_45]], %[[VAL_49]]
+// CHECK:         br label %[[VAL_52:.*]]
+// CHECK:       omp.region.cont1:                                 ; preds = %[[VAL_51]]
+// CHECK:         br label %[[VAL_32]]
+// CHECK:       omp_loop.inc:                                     ; preds = %[[VAL_52]]
+// CHECK:         %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1
+// CHECK:         br label %[[VAL_31]]
+// CHECK:       omp_loop.exit:                                    ; preds = %[[VAL_50]], %[[VAL_35]]
+// CHECK:         call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]])
+// CHECK:         %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]])
+// CHECK:         br label %[[VAL_54:.*]]
+// CHECK:       omp_loop.after:                                   ; preds = %[[VAL_38]]
+// CHECK:         br label %[[VAL_55:.*]]
+// CHECK:       omp.region.cont:                                  ; preds = %[[VAL_54]]
+// CHECK:         ret void
+// CHECK:       .cncl:                                            ; preds = %[[VAL_44]]
+// CHECK:         br label %[[VAL_38]]
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 0cc96deacd954..ed355096b702e 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -26,22 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
 
 // -----
 
-llvm.func @cancel_wsloop(%lb : i32, %ub : i32, %step: i32) {
-  // expected-error @below {{LLVM Translation failed for operation: omp.wsloop}}
-  omp.wsloop {
-    // expected-error @below {{LLVM Translation failed for operation: omp.loop_nest}}
-    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      // expected-error @below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
-      // expected-error @below {{LLVM Translation failed for operation: omp.cancel}}
-      omp.cancel cancellation_construct_type(loop)
-      omp.yield
-    }
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @cancel_taskgroup() {
  // expected-error @below {{LLVM Translation failed for operation: omp.taskgroup}}
   omp.taskgroup {



More information about the Mlir-commits mailing list