[Mlir-commits] [mlir] e402009 - [mlir][OpenMP] cancel(lation point) taskgroup LLVMIR (#137841)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 8 03:16:02 PDT 2025
Author: Tom Eccles
Date: 2025-05-08T11:15:58+01:00
New Revision: e40200901cf1af860db9ded5c03b7b104396e429
URL: https://github.com/llvm/llvm-project/commit/e40200901cf1af860db9ded5c03b7b104396e429
DIFF: https://github.com/llvm/llvm-project/commit/e40200901cf1af860db9ded5c03b7b104396e429.diff
LOG: [mlir][OpenMP] cancel(lation point) taskgroup LLVMIR (#137841)
A cancel or cancellation point for taskgroup is always nested inside of
a task inside of the taskgroup. For the task which is cancelled, it is
that task which needs to be cleaned up: not the owning taskgroup.
Therefore the cancellation branch is handled in the conversion of the
task, not in the conversion of the taskgroup.
I added a firstprivate clause to the test for cancel taskgroup to
demonstrate that the block being branched to is the same block where
mandatory cleanup code is added. Cancellation point follows exactly the
same code path.
Added:
Modified:
flang/docs/OpenMPSupport.md
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/openmp-cancel.mlir
mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
mlir/test/Target/LLVMIR/openmp-todo.mlir
Removed:
################################################################################
diff --git a/flang/docs/OpenMPSupport.md b/flang/docs/OpenMPSupport.md
index 587723890d226..28e13d3179bd2 100644
--- a/flang/docs/OpenMPSupport.md
+++ b/flang/docs/OpenMPSupport.md
@@ -51,8 +51,8 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
| depend clause | P | depend clause with array sections are not supported |
| declare reduction construct | N | |
| atomic construct extensions | Y | |
-| cancel construct | N | |
-| cancellation point construct | N | |
+| cancel construct | Y | |
+| cancellation point construct | Y | |
| parallel do simd construct | P | linear clause is not supported |
| target teams construct | P | device and reduction clauses are not supported |
| teams distribute construct | P | reduction and dist_schedule clauses not supported |
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2e8e5a6ca5c2a..9f7b5605556e6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -161,8 +161,18 @@ static LogicalResult checkImplementationStatus(Operation &op) {
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
omp::ClauseCancellationConstructType cancelledDirective =
op.getCancelDirective();
- if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
- result = todo("cancel directive construct type not yet supported");
+ // Cancelling a taskloop is not yet supported because we don't yet have LLVM
+ // IR conversion for taskloop
+ if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
+ Operation *parent = op->getParentOp();
+ while (parent) {
+ if (parent->getDialect() == op->getDialect())
+ break;
+ parent = parent->getParentOp();
+ }
+ if (isa_and_nonnull<omp::TaskloopOp>(parent))
+ result = todo("cancel directive inside of taskloop");
+ }
};
auto checkDepend = [&todo](auto op, LogicalResult &result) {
if (!op.getDependVars().empty() || op.getDependKinds())
@@ -1889,6 +1899,55 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
}
}
+/// Shared implementation of a callback which adds a terminator for the new
+/// block created for the branch taken when an OpenMP construct is cancelled.
+/// The terminator is saved in \p cancelTerminators. This callback is invoked
+/// only if there is cancellation inside of the construct body.
+/// The terminator will need to be fixed to branch to the correct block to
+/// cleanup the construct.
+static void
+pushCancelFinalizationCB(SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
+ llvm::IRBuilderBase &llvmBuilder,
+ llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
+ llvm::omp::Directive cancelDirective) {
+ auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
+ llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
+
+ // ip is currently in the block branched to if cancellation occurred.
+ // We need to create a branch to terminate that block.
+ llvmBuilder.restoreIP(ip);
+
+ // We must still clean up the construct after cancelling it, so we need to
+ // branch to the block that finalizes that construct.
+ // That block has not been created yet so use this block as a dummy for now
+ // and fix this after creating the operation.
+ cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
+ return llvm::Error::success();
+ };
+ // We have to add the cleanup to the OpenMPIRBuilder before the body gets
+ // created in case the body contains omp.cancel (which will then expect to be
+ // able to find this cleanup callback).
+ ompBuilder.pushFinalizationCB(
+ {finiCB, cancelDirective, constructIsCancellable(op)});
+}
+
+/// If we cancelled the construct, we should branch to the finalization block of
+/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
+/// is immediately before the continuation block. Now this finalization has
+/// been created we can fix the branch.
+static void
+popCancelFinalizationCB(const ArrayRef<llvm::BranchInst *> cancelTerminators,
+ llvm::OpenMPIRBuilder &ompBuilder,
+ const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
+ ompBuilder.popFinalizationCB();
+ llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
+ for (llvm::BranchInst *cancelBranch : cancelTerminators) {
+ assert(cancelBranch->getNumSuccessors() == 1 &&
+ "cancel branch should have one target");
+ cancelBranch->setSuccessor(0, constructFini);
+ }
+}
+
namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute.
@@ -2202,6 +2261,14 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
return llvm::Error::success();
};
+ llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
+ SmallVector<llvm::BranchInst *> cancelTerminators;
+ // The directive to match here is OMPD_taskgroup because it is the taskgroup
+ // which is canceled. This is handled here because it is the task's cleanup
+ // block which should be branched to.
+ pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
+ llvm::omp::Directive::OMPD_taskgroup);
+
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
moduleTranslation, dds);
@@ -2219,6 +2286,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
if (failed(handleError(afterIP, *taskOp)))
return failure();
+ // Set the correct branch target for task cancellation
+ popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
+
builder.restoreIP(*afterIP);
return success();
}
@@ -2349,28 +2419,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
: llvm::omp::WorksharingLoopType::ForStaticLoop;
SmallVector<llvm::BranchInst *> cancelTerminators;
- // This callback is invoked only if there is cancellation inside of the wsloop
- // body.
- auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
- llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
- llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
-
- // ip is currently in the block branched to if cancellation occured.
- // We need to create a branch to terminate that block.
- llvmBuilder.restoreIP(ip);
-
- // We must still clean up the wsloop after cancelling it, so we need to
- // branch to the block that finalizes the wsloop.
- // That block has not been created yet so use this block as a dummy for now
- // and fix this after creating the wsloop.
- cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
- return llvm::Error::success();
- };
- // We have to add the cleanup to the OpenMPIRBuilder before the body gets
- // created in case the body contains omp.cancel (which will then expect to be
- // able to find this cleanup callback).
- ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
- constructIsCancellable(wsloopOp)});
+ pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
+ llvm::omp::Directive::OMPD_for);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
@@ -2393,18 +2443,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(wsloopIP, opInst)))
return failure();
- ompBuilder->popFinalizationCB();
- if (!cancelTerminators.empty()) {
- // If we cancelled the loop, we should branch to the finalization block of
- // the wsloop (which is always immediately before the loop continuation
- // block). Now the finalization has been created, we can fix the branch.
- llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
- for (llvm::BranchInst *cancelBranch : cancelTerminators) {
- assert(cancelBranch->getNumSuccessors() == 1 &&
- "cancel branch should have one target");
- cancelBranch->setSuccessor(0, wsloopFini);
- }
- }
+ // Set the correct branch target for wsloop cancellation
+ popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
// Process the reductions if required.
if (failed(createReductionsAndCleanup(
@@ -3060,12 +3100,12 @@ static llvm::omp::Directive convertCancellationConstructType(
static LogicalResult
convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
- llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-
if (failed(checkImplementationStatus(*op.getOperation())))
return failure();
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+
llvm::Value *ifCond = nullptr;
if (Value ifVar = op.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -3088,12 +3128,12 @@ static LogicalResult
convertOmpCancellationPoint(omp::CancellationPointOp op,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
- llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-
if (failed(checkImplementationStatus(*op.getOperation())))
return failure();
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+
llvm::omp::Directive cancelledDirective =
convertCancellationConstructType(op.getCancelDirective());
diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
index 3c195a98d1000..21241702ad569 100644
--- a/mlir/test/Target/LLVMIR/openmp-cancel.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
@@ -243,3 +243,51 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
// CHECK: ret void
// CHECK: .cncl: ; preds = %[[VAL_44]]
// CHECK: br label %[[VAL_38]]
+
+omp.private {type = firstprivate} @i32_priv : i32 copy {
+^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %0 = llvm.load %arg0 : !llvm.ptr -> i32
+ llvm.store %0, %arg1 : i32, !llvm.ptr
+ omp.yield(%arg1 : !llvm.ptr)
+}
+
+llvm.func @do_something(!llvm.ptr)
+
+llvm.func @cancel_taskgroup(%arg0: !llvm.ptr) {
+ omp.taskgroup {
+// Using firstprivate clause so we have some end of task cleanup to branch to
+// after the cancellation.
+ omp.task private(@i32_priv %arg0 -> %arg1 : !llvm.ptr) {
+ omp.cancel cancellation_construct_type(taskgroup)
+ llvm.call @do_something(%arg1) : (!llvm.ptr) -> ()
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+}
+// CHECK-LABEL: define internal void @cancel_taskgroup..omp_par(
+// CHECK: task.alloca:
+// CHECK: %[[VAL_21:.*]] = load ptr, ptr %[[VAL_22:.*]], align 8
+// CHECK: %[[VAL_23:.*]] = getelementptr { ptr }, ptr %[[VAL_21]], i32 0, i32 0
+// CHECK: %[[VAL_24:.*]] = load ptr, ptr %[[VAL_23]], align 8, !align !1
+// CHECK: br label %[[VAL_25:.*]]
+// CHECK: task.body: ; preds = %[[VAL_26:.*]]
+// CHECK: %[[VAL_27:.*]] = getelementptr { i32 }, ptr %[[VAL_24]], i32 0, i32 0
+// CHECK: br label %[[VAL_28:.*]]
+// CHECK: omp.task.region: ; preds = %[[VAL_25]]
+// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_29]], i32 4)
+// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0
+// CHECK: br i1 %[[VAL_31]], label %omp.task.region.split, label %omp.task.region.cncl
+// CHECK: omp.task.region.cncl:
+// CHECK: br label %omp.region.cont2
+// CHECK: omp.region.cont2:
+// Both cancellation and normal paths reach the end-of-task cleanup:
+// CHECK: tail call void @free(ptr %[[VAL_24]])
+// CHECK: br label %task.exit.exitStub
+// CHECK: omp.task.region.split:
+// CHECK: call void @do_something(ptr %[[VAL_27]])
+// CHECK: br label %omp.region.cont2
+// CHECK: task.exit.exitStub:
+// CHECK: ret void
diff --git a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
index bbb313c113567..5e0d3f9f7e293 100644
--- a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
@@ -186,3 +186,33 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) {
// CHECK: ret void
// CHECK: omp.loop_nest.region.cncl: ; preds = %[[VAL_100]]
// CHECK: br label %[[VAL_96]]
+
+
+llvm.func @cancellation_point_taskgroup() {
+ omp.taskgroup {
+ omp.task {
+ omp.cancellation_point cancellation_construct_type(taskgroup)
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+}
+// CHECK-LABEL: define internal void @cancellation_point_taskgroup..omp_par(
+// CHECK: task.alloca:
+// CHECK: br label %[[VAL_50:.*]]
+// CHECK: task.body: ; preds = %[[VAL_51:.*]]
+// CHECK: br label %[[VAL_52:.*]]
+// CHECK: omp.task.region: ; preds = %[[VAL_50]]
+// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK: %[[VAL_54:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[VAL_53]], i32 4)
+// CHECK: %[[VAL_55:.*]] = icmp eq i32 %[[VAL_54]], 0
+// CHECK: br i1 %[[VAL_55]], label %omp.task.region.split, label %omp.task.region.cncl
+// CHECK: omp.task.region.cncl:
+// CHECK: br label %omp.region.cont1
+// CHECK: omp.region.cont1:
+// CHECK: br label %task.exit.exitStub
+// CHECK: omp.task.region.split:
+// CHECK: br label %omp.region.cont1
+// CHECK: task.exit.exitStub:
+// CHECK: ret void
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index a67c61f23631f..9a83b46efddca 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -26,40 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
// -----
-llvm.func @cancel_taskgroup() {
- // expected-error at below {{LLVM Translation failed for operation: omp.taskgroup}}
- omp.taskgroup {
- // expected-error at below {{LLVM Translation failed for operation: omp.task}}
- omp.task {
- // expected-error at below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
- // expected-error at below {{LLVM Translation failed for operation: omp.cancel}}
- omp.cancel cancellation_construct_type(taskgroup)
- omp.terminator
- }
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
-llvm.func @cancellation_point_taskgroup() {
- // expected-error at below {{LLVM Translation failed for operation: omp.taskgroup}}
- omp.taskgroup {
- // expected-error at below {{LLVM Translation failed for operation: omp.task}}
- omp.task {
- // expected-error at below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancellation_point operation}}
- // expected-error at below {{LLVM Translation failed for operation: omp.cancellation_point}}
- omp.cancellation_point cancellation_construct_type(taskgroup)
- omp.terminator
- }
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
llvm.func @do_simd(%lb : i32, %ub : i32, %step : i32) {
omp.wsloop {
// expected-warning at below {{simd information on composite construct discarded}}
More information about the Mlir-commits
mailing list