[clang] [llvm] [mlir] [OMPIRBuilder] always leave PARALLEL via the same barrier (PR #164586)

Tom Eccles via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 22 02:26:54 PDT 2025


https://github.com/tblah created https://github.com/llvm/llvm-project/pull/164586

A barrier pauses execution until all threads reach it. If some threads go to a different barrier, we deadlock. In practice this means the finalization callback must only be run once. Fix this by ensuring that we always go through the same finalization block, whether or not the thread is cancelled, and no matter which cancellation point causes the cancellation.
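
The "run the finalization callback only once" idea can be sketched outside LLVM as a small memoization pattern. This is a minimal, self-contained model and not the OMPIRBuilder code: `Block`, `getOrCreateFini` and `emitCancellationCheck` are hypothetical names, the `FinalizationInfo` here is a toy struct, and blocks are modelled as plain lists of strings. It only illustrates that every cancellation point branches to one shared `.fini` block and that the callback which populates that block runs a single time.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <vector>

// Toy stand-in for a basic block: a name plus a list of textual instructions.
struct Block {
  std::string Name;
  std::vector<std::string> Insts;
};

// Toy stand-in for the finalization info: the callback that fills in the
// finalization code, plus the memoized block it was emitted into.
struct FinalizationInfo {
  std::function<void(Block &)> FiniCB;
  std::unique_ptr<Block> FiniBB; // null until first requested
};

// Create the shared finalization block on first use; reuse it afterwards.
Block &getOrCreateFini(FinalizationInfo &FI) {
  if (!FI.FiniBB) {
    FI.FiniBB.reset(new Block{".fini", {}});
    FI.FiniCB(*FI.FiniBB); // only reached on the first request
  }
  return *FI.FiniBB;
}

// Each cancellation point only emits a branch to the shared block.
void emitCancellationCheck(FinalizationInfo &FI, Block &CancelBB) {
  Block &Fini = getOrCreateFini(FI);
  CancelBB.Insts.push_back("br label %" + Fini.Name);
}
```

Calling `emitCancellationCheck` for two different cancellation blocks against the same `FinalizationInfo` leaves both branching to `.fini`, with the callback invoked once.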

The old callback only affected PARALLEL, so it has been moved into the code generating PARALLEL. For this reason, we don't need similar changes for other cancellable constructs. We need to create the barrier on the shared exit from the outlined function instead of only on the cancelled branch to make sure that threads exiting normally (without cancellation) meet the same barriers as those which were cancelled. For example, previously we might have generated code like

```
...
  %ret = call i32 @__kmpc_cancel(...)
  %cond = icmp eq i32 %ret, 0
  br i1 %cond, label %continue, label %cancel

continue:
  // do the rest of the callback, eventually branching to %fini
  br label %fini

cancel:
  // Populated by the callback:
  // unsafe: if any thread makes it to the end without being cancelled
  // it won't reach this barrier and then the program will deadlock
  %unused = call i32 @__kmpc_cancel_barrier(...)
  br label %fini

fini:
  // run destructors etc
  ret
```

In the new version the barrier is moved into fini. I generate it *after* the destructors because the standard describes the barrier as occurring after the end of the parallel region.

```
...
  %ret = call i32 @__kmpc_cancel(...)
  %cond = icmp eq i32 %ret, 0
  br i1 %cond, label %continue, label %cancel

continue:
  // do the rest of the callback, eventually branching to %fini
  br label %fini

cancel:
  br label %fini

fini:
  // run destructors etc
  // safe so long as every exit from the function happens via this block:
  %unused = call i32 @__kmpc_cancel_barrier(...)
  ret
```
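
The difference between the two listings can be checked with a toy arrival count. This is only a model under simplifying assumptions: there are no real threads and no OpenMP runtime, `Scheme`, `barrierArrivals` and `teamIsReleased` are made-up names, and a barrier is treated as released only when every thread in the team arrives at it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Where the cancellation barrier is emitted in the outlined function.
enum class Scheme {
  BarrierOnCancelBranch, // old: only the cancelled path calls the barrier
  BarrierInSharedFini    // new: every exit path goes through the barrier
};

// Count how many threads arrive at the cancellation barrier.
// Cancelled[i] is true if thread i left the region via the cancel branch.
std::size_t barrierArrivals(Scheme S, const std::vector<bool> &Cancelled) {
  std::size_t Arrivals = 0;
  for (bool WasCancelled : Cancelled)
    if (S == Scheme::BarrierInSharedFini || WasCancelled)
      ++Arrivals;
  return Arrivals;
}

// In this model the team makes progress if nobody waits at the barrier,
// or if all threads reach it. Anything in between is a deadlock.
bool teamIsReleased(Scheme S, const std::vector<bool> &Cancelled) {
  std::size_t Arrivals = barrierArrivals(S, Cancelled);
  return Arrivals == 0 || Arrivals == Cancelled.size();
}
```

With a team in which only some threads take the cancel branch, the old scheme leaves the cancelled threads waiting for arrivals that never come, while the shared-exit scheme has every thread arrive.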

To achieve this, the barrier is now generated alongside the finalization code instead of in the callback. This is the reason for the changes to the unit test.

I'm unsure whether clang with the OMPIRBuilder backend should keep the incorrect behaviour of generating the barrier only on the cancellation branch, since that would match clang's ordinary codegen. For now I have opted to remove it entirely because it is a deadlock waiting to happen.

From 0f80c45bd5863cf11e051ba8701d532ded5a16fd Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Thu, 25 Sep 2025 11:12:24 +0000
Subject: [PATCH] [OMPIRBuilder] always leave PARALLEL via the same barrier

A barrier will pause execution until all threads reach it. If some go to
a different barrier then we deadlock. This manifests in that the
finalization callback must only be run once. Fix by ensuring we always
go through the same finalization block whether the thread is cancelled
or not and no matter which cancellation point causes the cancellation.

The old callback only affected PARALLEL, so it has been moved into the
code generating PARALLEL. For this reason, we don't need similar changes
for other cancellable constructs. We need to create the barrier on the
shared exit from the outlined function instead of only on the cancelled
branch to make sure that threads exiting normally (without cancellation)
meet the same barriers as those which were cancelled. For example,
previously we might have generated code like

```
...
  %ret = call i32 @__kmpc_cancel(...)
  %cond = icmp eq i32 %ret, 0
  br i1 %cond, label %continue, label %cancel

continue:
  // do the rest of the callback, eventually branching to %fini
  br label %fini

cancel:
  // Populated by the callback:
  // unsafe: if any thread makes it to the end without being cancelled
  // it won't reach this barrier and then the program will deadlock
  %unused = call i32 @__kmpc_cancel_barrier(...)
  br label %fini

fini:
  // run destructors etc
  ret
```

In the new version the barrier is moved into fini. I generate it *after*
the destructors because the standard describes the barrier as occurring
after the end of the parallel region.

```
...
  %ret = call i32 @__kmpc_cancel(...)
  %cond = icmp eq i32 %ret, 0
  br i1 %cond, label %continue, label %cancel

continue:
  // do the rest of the callback, eventually branching to %fini
  br label %fini

cancel:
  br label %fini

fini:
  // run destructors etc
  // safe so long as every exit from the function happens via this block:
  %unused = call i32 @__kmpc_cancel_barrier(...)
  ret
```

To achieve this, the barrier is now generated alongside the finalization
code instead of in the callback. This is the reason for the changes to
the unit test.

I'm unsure if I should keep the incorrect barrier generation callback only
on the cancellation branch in clang with the OMPIRBuilder backend
because that would match clang's ordinary codegen. Right now I have
opted to remove it entirely because it is a deadlock waiting to happen.
---
 clang/test/OpenMP/cancel_codegen.cpp          | 36 +++++-----
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  8 ++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 72 ++++++++-----------
 .../Frontend/OpenMPIRBuilderTest.cpp          | 65 +++++++----------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 14 +++-
 .../Target/LLVMIR/openmp-barrier-cancel.mlir  | 14 ++--
 mlir/test/Target/LLVMIR/openmp-cancel.mlir    | 20 +++---
 .../LLVMIR/openmp-cancellation-point.mlir     | 12 ++--
 8 files changed, 124 insertions(+), 117 deletions(-)

diff --git a/clang/test/OpenMP/cancel_codegen.cpp b/clang/test/OpenMP/cancel_codegen.cpp
index 16e7542a8e826..150cdb9b2cc14 100644
--- a/clang/test/OpenMP/cancel_codegen.cpp
+++ b/clang/test/OpenMP/cancel_codegen.cpp
@@ -811,16 +811,16 @@ for (int i = 0; i < argc; ++i) {
 // CHECK3-NEXT:    br label [[OMP_SECTION_LOOP_BODY_CASE23_SECTION_AFTER:%.*]]
 // CHECK3:       omp_section_loop.body.case23.section.after:
 // CHECK3-NEXT:    br label [[OMP_SECTION_LOOP_BODY16_SECTIONS_AFTER]]
-// CHECK3:       omp_section_loop.body.case25:
+// CHECK3:       omp_section_loop.body.case26:
 // CHECK3-NEXT:    [[OMP_GLOBAL_THREAD_NUM27:%.*]] = call i32 @__kmpc_global_thread_num(ptr @[[GLOB1]])
 // CHECK3-NEXT:    [[TMP18:%.*]] = call i32 @__kmpc_cancel(ptr @[[GLOB1]], i32 [[OMP_GLOBAL_THREAD_NUM27]], i32 3)
 // CHECK3-NEXT:    [[TMP19:%.*]] = icmp eq i32 [[TMP18]], 0
 // CHECK3-NEXT:    br i1 [[TMP19]], label [[OMP_SECTION_LOOP_BODY_CASE25_SPLIT:%.*]], label [[OMP_SECTION_LOOP_BODY_CASE25_CNCL:%.*]]
-// CHECK3:       omp_section_loop.body.case25.split:
+// CHECK3:       omp_section_loop.body.case26.split:
 // CHECK3-NEXT:    br label [[OMP_SECTION_LOOP_BODY_CASE25_SECTION_AFTER26:%.*]]
-// CHECK3:       omp_section_loop.body.case25.section.after26:
+// CHECK3:       omp_section_loop.body.case26.section.after27:
 // CHECK3-NEXT:    br label [[OMP_SECTION_LOOP_BODY_CASE25_SECTION_AFTER:%.*]]
-// CHECK3:       omp_section_loop.body.case25.section.after:
+// CHECK3:       omp_section_loop.body.case26.section.after:
 // CHECK3-NEXT:    br label [[OMP_SECTION_LOOP_BODY16_SECTIONS_AFTER]]
 // CHECK3:       omp_section_loop.body16.sections.after:
 // CHECK3-NEXT:    br label [[OMP_SECTION_LOOP_INC17]]
@@ -891,10 +891,12 @@ for (int i = 0; i < argc; ++i) {
 // CHECK3:       .cancel.exit:
 // CHECK3-NEXT:    br label [[CANCEL_EXIT:%.*]]
 // CHECK3:       omp_section_loop.body.case.cncl:
-// CHECK3-NEXT:    br label [[OMP_SECTION_LOOP_EXIT]]
-// CHECK3:       omp_section_loop.body.case23.cncl:
+// CHECK3-NEXT:    br label [[FINI10:.*]]
+// CHECK3:       .fini25:
 // CHECK3-NEXT:    br label [[OMP_SECTION_LOOP_EXIT18]]
-// CHECK3:       omp_section_loop.body.case25.cncl:
+// CHECK3:       omp_section_loop.body.case26.cncl:
+// CHECK3-NEXT:    br label [[FINI29:.*]]
+// CHECK3:       .fini29:
 // CHECK3-NEXT:    br label [[OMP_SECTION_LOOP_EXIT18]]
 // CHECK3:       .cancel.continue:
 // CHECK3-NEXT:    br label [[OMP_IF_END:%.*]]
@@ -967,8 +969,10 @@ for (int i = 0; i < argc; ++i) {
 // CHECK3-NEXT:    [[TMP8:%.*]] = call i32 @__kmpc_cancel_barrier(ptr @[[GLOB3:[0-9]+]], i32 [[OMP_GLOBAL_THREAD_NUM4]])
 // CHECK3-NEXT:    [[TMP9:%.*]] = icmp eq i32 [[TMP8]], 0
 // CHECK3-NEXT:    br i1 [[TMP9]], label [[DOTCONT:%.*]], label [[DOTCNCL5:%.*]]
-// CHECK3:       .cncl5:
-// CHECK3-NEXT:    br label [[OMP_PAR_OUTLINED_EXIT_EXITSTUB:%.*]]
+// CHECK3:       .cncl4:
+// CHECK3-NEXT:    br label [[FINI:%.*]]
+// CHECK3:       .fini
+// CHECK3-NEXT:    br label %[[EXIT_STUB:omp.par.exit.exitStub]]
 // CHECK3:       .cont:
 // CHECK3-NEXT:    [[TMP10:%.*]] = load i32, ptr [[LOADGEP_ARGC_ADDR]], align 4
 // CHECK3-NEXT:    [[TMP11:%.*]] = load ptr, ptr [[LOADGEP_ARGV_ADDR]], align 8
@@ -984,16 +988,14 @@ for (int i = 0; i < argc; ++i) {
 // CHECK3:       omp.par.region.parallel.after:
 // CHECK3-NEXT:    br label [[OMP_PAR_PRE_FINALIZE:%.*]]
 // CHECK3:       omp.par.pre_finalize:
-// CHECK3-NEXT:    br label [[OMP_PAR_OUTLINED_EXIT_EXITSTUB]]
+// CHECK3-NEXT:    br label [[FINI]]
 // CHECK3:       14:
 // CHECK3-NEXT:    [[OMP_GLOBAL_THREAD_NUM1:%.*]] = call i32 @__kmpc_global_thread_num(ptr @[[GLOB1]])
 // CHECK3-NEXT:    [[TMP15:%.*]] = call i32 @__kmpc_cancel(ptr @[[GLOB1]], i32 [[OMP_GLOBAL_THREAD_NUM1]], i32 1)
 // CHECK3-NEXT:    [[TMP16:%.*]] = icmp eq i32 [[TMP15]], 0
 // CHECK3-NEXT:    br i1 [[TMP16]], label [[DOTSPLIT:%.*]], label [[DOTCNCL:%.*]]
 // CHECK3:       .cncl:
-// CHECK3-NEXT:    [[OMP_GLOBAL_THREAD_NUM2:%.*]] = call i32 @__kmpc_global_thread_num(ptr @[[GLOB1]])
-// CHECK3-NEXT:    [[TMP17:%.*]] = call i32 @__kmpc_cancel_barrier(ptr @[[GLOB2]], i32 [[OMP_GLOBAL_THREAD_NUM2]])
-// CHECK3-NEXT:    br label [[OMP_PAR_OUTLINED_EXIT_EXITSTUB]]
+// CHECK3-NEXT:    br label [[FINI]]
 // CHECK3:       .split:
 // CHECK3-NEXT:    br label [[TMP4]]
 // CHECK3:       omp.par.exit.exitStub:
@@ -1089,7 +1091,7 @@ for (int i = 0; i < argc; ++i) {
 // CHECK3:       .omp.sections.case.split:
 // CHECK3-NEXT:    br label [[DOTOMP_SECTIONS_EXIT]]
 // CHECK3:       .omp.sections.case.cncl:
-// CHECK3-NEXT:    br label [[CANCEL_CONT:%.*]]
+// CHECK3-NEXT:    br label [[FINI:%.*]]
 // CHECK3:       .omp.sections.exit:
 // CHECK3-NEXT:    br label [[OMP_INNER_FOR_INC:%.*]]
 // CHECK3:       omp.inner.for.inc:
@@ -1100,7 +1102,7 @@ for (int i = 0; i < argc; ++i) {
 // CHECK3:       omp.inner.for.end:
 // CHECK3-NEXT:    [[OMP_GLOBAL_THREAD_NUM3:%.*]] = call i32 @__kmpc_global_thread_num(ptr @[[GLOB19:[0-9]+]])
 // CHECK3-NEXT:    call void @__kmpc_for_static_fini(ptr @[[GLOB15]], i32 [[OMP_GLOBAL_THREAD_NUM3]])
-// CHECK3-NEXT:    br label [[CANCEL_CONT]]
+// CHECK3-NEXT:    br label [[CANCEL_CONT:.*]]
 // CHECK3:       cancel.cont:
 // CHECK3-NEXT:    ret void
 // CHECK3:       cancel.exit:
@@ -1153,6 +1155,8 @@ for (int i = 0; i < argc; ++i) {
 // CHECK3:       .omp.sections.case.split:
 // CHECK3-NEXT:    br label [[DOTOMP_SECTIONS_EXIT]]
 // CHECK3:       .omp.sections.case.cncl:
+// CHECK3-NEXT:    br label [[DOTFINI:%.*]]
+// CHECK3:       .fini:
 // CHECK3-NEXT:    br label [[CANCEL_CONT:%.*]]
 // CHECK3:       .omp.sections.case2:
 // CHECK3-NEXT:    [[OMP_GLOBAL_THREAD_NUM3:%.*]] = call i32 @__kmpc_global_thread_num(ptr @[[GLOB1]])
@@ -1164,7 +1168,7 @@ for (int i = 0; i < argc; ++i) {
 // CHECK3:       .omp.sections.case2.section.after:
 // CHECK3-NEXT:    br label [[DOTOMP_SECTIONS_EXIT]]
 // CHECK3:       .omp.sections.case2.cncl:
-// CHECK3-NEXT:    br label [[OMP_INNER_FOR_END]]
+// CHECK3-NEXT:    br label [[FINI:.*]]
 // CHECK3:       .omp.sections.exit:
 // CHECK3-NEXT:    br label [[OMP_INNER_FOR_INC:%.*]]
 // CHECK3:       omp.inner.for.inc:
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 5331cb5abdc6f..3658dfe263424 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -580,6 +580,11 @@ class OpenMPIRBuilder {
 
     /// Flag to indicate if the directive is cancellable.
     bool IsCancellable;
+
+    /// The basic block to which control should be transferred to
+    /// implement the FiniCB. Memoized to avoid generating finalization
+    /// multiple times.
+    llvm::BasicBlock *FiniBB = nullptr;
   };
 
   /// Push a finalization callback on the finalization stack.
@@ -2181,8 +2186,7 @@ class OpenMPIRBuilder {
   ///
   /// \return an error, if any were triggered during execution.
   LLVM_ABI Error emitCancelationCheckImpl(Value *CancelFlag,
-                                          omp::Directive CanceledDirective,
-                                          FinalizeCallbackTy ExitCB = {});
+                                          omp::Directive CanceledDirective);
 
   /// Generate a target region entry call.
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 286ed039b1214..9fe0c3f880480 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1108,21 +1108,9 @@ OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
   Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
   Value *Result = Builder.CreateCall(
       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
-  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error {
-    if (CanceledDirective == OMPD_parallel) {
-      IRBuilder<>::InsertPointGuard IPG(Builder);
-      Builder.restoreIP(IP);
-      return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
-                           omp::Directive::OMPD_unknown,
-                           /* ForceSimpleCall */ false,
-                           /* CheckCancelFlag */ false)
-          .takeError();
-    }
-    return Error::success();
-  };
 
   // The actual cancel logic is shared with others, e.g., cancel_barriers.
-  if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB))
+  if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective))
     return Err;
 
   // Update the insertion point and remove the terminator we introduced.
@@ -1159,21 +1147,9 @@ OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc,
   Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
   Value *Result = Builder.CreateCall(
       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancellationpoint), Args);
-  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error {
-    if (CanceledDirective == OMPD_parallel) {
-      IRBuilder<>::InsertPointGuard IPG(Builder);
-      Builder.restoreIP(IP);
-      return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
-                           omp::Directive::OMPD_unknown,
-                           /* ForceSimpleCall */ false,
-                           /* CheckCancelFlag */ false)
-          .takeError();
-    }
-    return Error::success();
-  };
 
   // The actual cancel logic is shared with others, e.g., cancel_barriers.
-  if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB))
+  if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective))
     return Err;
 
   // Update the insertion point and remove the terminator we introduced.
@@ -1277,8 +1253,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch(
 }
 
 Error OpenMPIRBuilder::emitCancelationCheckImpl(
-    Value *CancelFlag, omp::Directive CanceledDirective,
-    FinalizeCallbackTy ExitCB) {
+    Value *CancelFlag, omp::Directive CanceledDirective) {
   assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
          "Unexpected cancellation!");
 
@@ -1305,13 +1280,17 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
 
   // From the cancellation block we finalize all variables and go to the
   // post finalization block that is known to the FiniCB callback.
-  Builder.SetInsertPoint(CancellationBlock);
-  if (ExitCB)
-    if (Error Err = ExitCB(Builder.saveIP()))
-      return Err;
   auto &FI = FinalizationStack.back();
-  if (Error Err = FI.FiniCB(Builder.saveIP()))
-    return Err;
+  if (!FI.FiniBB) {
+    llvm::IRBuilderBase::InsertPointGuard Guard(Builder);
+    FI.FiniBB = BasicBlock::Create(BB->getContext(), ".fini", BB->getParent());
+    Builder.SetInsertPoint(FI.FiniBB);
+    // FiniCB adds the branch to the exit stub.
+    if (Error Err = FI.FiniCB(Builder.saveIP()))
+      return Err;
+  }
+  Builder.SetInsertPoint(CancellationBlock);
+  Builder.CreateBr(FI.FiniBB);
 
   // The continuation block is where code generation continues.
   Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
@@ -1800,8 +1779,18 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
   Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
 
   InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
-  if (Error Err = FiniCB(PreFiniIP))
-    return Err;
+  if (!FiniInfo.FiniBB) {
+    if (Error Err = FiniCB(PreFiniIP))
+      return Err;
+  } else {
+    llvm::IRBuilderBase::InsertPointGuard Guard{Builder};
+    Builder.restoreIP(PreFiniIP);
+    Builder.CreateBr(FiniInfo.FiniBB);
+    // There's currently a branch to omp.par.exit. Delete it. We will get there
+    // via the fini block
+    if (llvm::Instruction *Term = Builder.GetInsertBlock()->getTerminator())
+      Term->eraseFromParent();
+  }
 
   // Register the outlined info.
   addOutlineInfo(std::move(OI));
@@ -6556,13 +6545,14 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitCommonDirectiveExit(
     FinalizationInfo Fi = FinalizationStack.pop_back_val();
     assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
 
-    if (Error Err = Fi.FiniCB(FinIP))
-      return Err;
-
-    BasicBlock *FiniBB = FinIP.getBlock();
-    Instruction *FiniBBTI = FiniBB->getTerminator();
+    if (!Fi.FiniBB) {
+      if (Error Err = Fi.FiniCB(FinIP))
+        return Err;
+      Fi.FiniBB = FinIP.getBlock();
+    }
 
     // set Builder IP for call creation
+    Instruction *FiniBBTI = Fi.FiniBB->getTerminator();
     Builder.SetInsertPoint(FiniBBTI);
   }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index e56872320b4ac..da1760a56d952 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -428,8 +428,8 @@ TEST_F(OpenMPIRBuilderTest, CreateCancel) {
                        OMPBuilder.createCancel(Loc, nullptr, OMPD_parallel));
   Builder.restoreIP(NewIP);
   EXPECT_FALSE(M->global_empty());
-  EXPECT_EQ(M->size(), 4U);
-  EXPECT_EQ(F->size(), 4U);
+  EXPECT_EQ(M->size(), 3U);
+  EXPECT_EQ(F->size(), 5U);
   EXPECT_EQ(BB->size(), 4U);
 
   CallInst *GTID = dyn_cast<CallInst>(&BB->front());
@@ -449,23 +449,16 @@ TEST_F(OpenMPIRBuilderTest, CreateCancel) {
   Instruction *CancelBBTI = Cancel->getParent()->getTerminator();
   EXPECT_EQ(CancelBBTI->getNumSuccessors(), 2U);
   EXPECT_EQ(CancelBBTI->getSuccessor(0), NewIP.getBlock());
-  EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 3U);
-  CallInst *GTID1 = dyn_cast<CallInst>(&CancelBBTI->getSuccessor(1)->front());
-  EXPECT_NE(GTID1, nullptr);
-  EXPECT_EQ(GTID1->arg_size(), 1U);
-  EXPECT_EQ(GTID1->getCalledFunction()->getName(), "__kmpc_global_thread_num");
-  EXPECT_FALSE(GTID1->getCalledFunction()->doesNotAccessMemory());
-  EXPECT_FALSE(GTID1->getCalledFunction()->doesNotFreeMemory());
-  CallInst *Barrier = dyn_cast<CallInst>(GTID1->getNextNode());
-  EXPECT_NE(Barrier, nullptr);
-  EXPECT_EQ(Barrier->arg_size(), 2U);
-  EXPECT_EQ(Barrier->getCalledFunction()->getName(), "__kmpc_cancel_barrier");
-  EXPECT_FALSE(Barrier->getCalledFunction()->doesNotAccessMemory());
-  EXPECT_FALSE(Barrier->getCalledFunction()->doesNotFreeMemory());
-  EXPECT_TRUE(Barrier->use_empty());
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 1U);
   EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(),
             1U);
-  EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0), CBB);
+  // cancel branch instruction (1) -> .cncl -> .fini -> CBB
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)
+                ->getTerminator()
+                ->getSuccessor(0)
+                ->getTerminator()
+                ->getSuccessor(0),
+            CBB);
 
   EXPECT_EQ(cast<CallInst>(Cancel)->getArgOperand(1), GTID);
 
@@ -497,8 +490,8 @@ TEST_F(OpenMPIRBuilderTest, CreateCancelIfCond) {
       OMPBuilder.createCancel(Loc, Builder.getTrue(), OMPD_parallel));
   Builder.restoreIP(NewIP);
   EXPECT_FALSE(M->global_empty());
-  EXPECT_EQ(M->size(), 4U);
-  EXPECT_EQ(F->size(), 7U);
+  EXPECT_EQ(M->size(), 3U);
+  EXPECT_EQ(F->size(), 8U);
   EXPECT_EQ(BB->size(), 1U);
   ASSERT_TRUE(isa<BranchInst>(BB->getTerminator()));
   ASSERT_EQ(BB->getTerminator()->getNumSuccessors(), 2U);
@@ -524,23 +517,15 @@ TEST_F(OpenMPIRBuilderTest, CreateCancelIfCond) {
   EXPECT_EQ(CancelBBTI->getSuccessor(0)->size(), 1U);
   EXPECT_EQ(CancelBBTI->getSuccessor(0)->getUniqueSuccessor(),
             NewIP.getBlock());
-  EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 3U);
-  CallInst *GTID1 = dyn_cast<CallInst>(&CancelBBTI->getSuccessor(1)->front());
-  EXPECT_NE(GTID1, nullptr);
-  EXPECT_EQ(GTID1->arg_size(), 1U);
-  EXPECT_EQ(GTID1->getCalledFunction()->getName(), "__kmpc_global_thread_num");
-  EXPECT_FALSE(GTID1->getCalledFunction()->doesNotAccessMemory());
-  EXPECT_FALSE(GTID1->getCalledFunction()->doesNotFreeMemory());
-  CallInst *Barrier = dyn_cast<CallInst>(GTID1->getNextNode());
-  EXPECT_NE(Barrier, nullptr);
-  EXPECT_EQ(Barrier->arg_size(), 2U);
-  EXPECT_EQ(Barrier->getCalledFunction()->getName(), "__kmpc_cancel_barrier");
-  EXPECT_FALSE(Barrier->getCalledFunction()->doesNotAccessMemory());
-  EXPECT_FALSE(Barrier->getCalledFunction()->doesNotFreeMemory());
-  EXPECT_TRUE(Barrier->use_empty());
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 1U);
   EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(),
             1U);
-  EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0), CBB);
+  EXPECT_EQ(CancelBBTI->getSuccessor(1)
+                ->getTerminator()
+                ->getSuccessor(0)
+                ->getTerminator()
+                ->getSuccessor(0),
+            CBB);
 
   EXPECT_EQ(cast<CallInst>(Cancel)->getArgOperand(1), GTID);
 
@@ -572,7 +557,7 @@ TEST_F(OpenMPIRBuilderTest, CreateCancelBarrier) {
   Builder.restoreIP(NewIP);
   EXPECT_FALSE(M->global_empty());
   EXPECT_EQ(M->size(), 3U);
-  EXPECT_EQ(F->size(), 4U);
+  EXPECT_EQ(F->size(), 5U);
   EXPECT_EQ(BB->size(), 4U);
 
   CallInst *GTID = dyn_cast<CallInst>(&BB->front());
@@ -595,7 +580,11 @@ TEST_F(OpenMPIRBuilderTest, CreateCancelBarrier) {
   EXPECT_EQ(BarrierBBTI->getSuccessor(1)->size(), 1U);
   EXPECT_EQ(BarrierBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(),
             1U);
-  EXPECT_EQ(BarrierBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0),
+  EXPECT_EQ(BarrierBBTI->getSuccessor(1)
+                ->getTerminator()
+                ->getSuccessor(0)
+                ->getTerminator()
+                ->getSuccessor(0),
             CBB);
 
   EXPECT_EQ(cast<CallInst>(Barrier)->getArgOperand(1), GTID);
@@ -1291,8 +1280,8 @@ TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) {
 
   EXPECT_EQ(NumBodiesGenerated, 1U);
   EXPECT_EQ(NumPrivatizedVars, 0U);
-  EXPECT_EQ(NumFinalizationPoints, 2U);
-  EXPECT_TRUE(FakeDestructor->hasNUses(2));
+  EXPECT_EQ(NumFinalizationPoints, 1U);
+  EXPECT_TRUE(FakeDestructor->hasNUses(1));
 
   Builder.restoreIP(AfterIP);
   Builder.CreateRetVoid();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8de49dd397d26..1629be2ba46c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2662,6 +2662,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
   assert(isByRef.size() == opInst.getNumReductionVars());
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  bool isCancellable = constructIsCancellable(opInst);
 
   if (failed(checkImplementationStatus(*opInst)))
     return failure();
@@ -2797,6 +2798,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                                   privateVarsInfo.privatizers)))
       return llvm::make_error<PreviouslyReportedError>();
 
+    // If we could be performing cancellation, add the cancellation barrier on
+    // the way out of the outlined region.
+    if (isCancellable) {
+      auto IPOrErr = ompBuilder->createBarrier(
+          llvm::OpenMPIRBuilder::LocationDescription(builder),
+          llvm::omp::Directive::OMPD_unknown,
+          /* ForceSimpleCall */ false,
+          /* CheckCancelFlag */ false);
+      if (!IPOrErr)
+        return IPOrErr.takeError();
+    }
+
     builder.restoreIP(oldIP);
     return llvm::Error::success();
   };
@@ -2810,7 +2823,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
   if (auto bind = opInst.getProcBindKind())
     pbKind = getProcBindKind(*bind);
-  bool isCancellable = constructIsCancellable(opInst);
 
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
diff --git a/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir
index c4b245667a1f3..6585549de7f96 100644
--- a/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir
@@ -29,22 +29,24 @@ llvm.func @test() {
 // CHECK:         %[[VAL_14:.*]] = icmp eq i32 %[[VAL_13]], 0
 // CHECK:         br i1 %[[VAL_14]], label %[[VAL_15:.*]], label %[[VAL_16:.*]]
 // CHECK:       omp.par.region1.cncl:                             ; preds = %[[VAL_11]]
-// CHECK:         %[[VAL_17:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
-// CHECK:         %[[VAL_18:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_17]])
-// CHECK:         br label %[[VAL_19:.*]]
+// CHECK:         br label %[[FINI:.*]]
+// CHECK:       .fini:
+// CHECK:         %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         %[[CNCL_BARRIER:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[TID]])
+// CHECK:         br label %[[EXIT_STUB:.*]]
 // CHECK:       omp.par.region1.split:                            ; preds = %[[VAL_11]]
 // CHECK:         %[[VAL_20:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
 // CHECK:         %[[VAL_21:.*]] = call i32 @__kmpc_cancel_barrier(ptr @3, i32 %[[VAL_20]])
 // CHECK:         %[[VAL_22:.*]] = icmp eq i32 %[[VAL_21]], 0
 // CHECK:         br i1 %[[VAL_22]], label %[[VAL_23:.*]], label %[[VAL_24:.*]]
 // CHECK:       omp.par.region1.split.cncl:                       ; preds = %[[VAL_15]]
-// CHECK:         br label %[[VAL_19]]
+// CHECK:         br label %[[FINI]]
 // CHECK:       omp.par.region1.split.cont:                       ; preds = %[[VAL_15]]
 // CHECK:         br label %[[VAL_25:.*]]
 // CHECK:       omp.region.cont:                                  ; preds = %[[VAL_23]]
 // CHECK:         br label %[[VAL_26:.*]]
 // CHECK:       omp.par.pre_finalize:                             ; preds = %[[VAL_25]]
-// CHECK:         br label %[[VAL_19]]
-// CHECK:       omp.par.exit.exitStub:                            ; preds = %[[VAL_26]], %[[VAL_24]], %[[VAL_16]]
+// CHECK:         br label %[[FINI]]
+// CHECK:       omp.par.exit.exitStub:
 // CHECK:         ret void
 
diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
index 21241702ad569..035e1f546bafa 100644
--- a/mlir/test/Target/LLVMIR/openmp-cancel.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
@@ -24,16 +24,18 @@ llvm.func @cancel_parallel() {
 // CHECK:         %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0
 // CHECK:         br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]]
 // CHECK:       omp.par.region1.cncl:                             ; preds = %[[VAL_12]]
+// CHECK:         br label %[[VAL_20:.*]]
+// CHECK:       .fini:
 // CHECK:         %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
 // CHECK:         %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]])
-// CHECK:         br label %[[VAL_20:.*]]
+// CHECK:         br label %[[EXIT_STUB:.*]]
 // CHECK:       omp.par.region1.split:                            ; preds = %[[VAL_12]]
 // CHECK:         br label %[[VAL_21:.*]]
 // CHECK:       omp.region.cont:                                  ; preds = %[[VAL_16]]
 // CHECK:         br label %[[VAL_22:.*]]
 // CHECK:       omp.par.pre_finalize:                             ; preds = %[[VAL_21]]
 // CHECK:         br label %[[VAL_20]]
-// CHECK:       omp.par.exit.exitStub:                            ; preds = %[[VAL_22]], %[[VAL_17]]
+// CHECK:       omp.par.exit.exitStub:
 // CHECK:         ret void
 
 llvm.func @cancel_parallel_if(%arg0 : i1) {
@@ -67,18 +69,20 @@ llvm.func @cancel_parallel_if(%arg0 : i1) {
 // CHECK:         br label %[[VAL_26:.*]]
 // CHECK:       omp.par.pre_finalize:                             ; preds = %[[VAL_25]]
 // CHECK:         br label %[[VAL_27:.*]]
-// CHECK:       5:                                                ; preds = %[[VAL_20]]
+// CHECK:       .fini:
+// CHECK:         %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]])
+// CHECK:         br label %[[EXIT_STUB:.*]]
+// CHECK:       6:                                                ; preds = %[[VAL_20]]
 // CHECK:         %[[VAL_28:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
 // CHECK:         %[[VAL_29:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_28]], i32 1)
 // CHECK:         %[[VAL_30:.*]] = icmp eq i32 %[[VAL_29]], 0
 // CHECK:         br i1 %[[VAL_30]], label %[[VAL_24]], label %[[VAL_31:.*]]
 // CHECK:       .cncl:                                            ; preds = %[[VAL_21]]
-// CHECK:         %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
-// CHECK:         %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]])
 // CHECK:         br label %[[VAL_27]]
 // CHECK:       .split:                                           ; preds = %[[VAL_21]]
 // CHECK:         br label %[[VAL_23]]
-// CHECK:       omp.par.exit.exitStub:                            ; preds = %[[VAL_31]], %[[VAL_26]]
+// CHECK:       omp.par.exit.exitStub:
 // CHECK:         ret void
 
 llvm.func @cancel_sections_if(%cond : i1) {
@@ -145,7 +149,7 @@ llvm.func @cancel_sections_if(%cond : i1) {
 // CHECK:       omp_section_loop.inc:                             ; preds = %[[VAL_23]]
 // CHECK:         %[[VAL_15]] = add nuw i32 %[[VAL_14]], 1
 // CHECK:         br label %[[VAL_12]]
-// CHECK:       omp_section_loop.exit:                            ; preds = %[[VAL_33]], %[[VAL_16]]
+// CHECK:       omp_section_loop.exit:
 // CHECK:         call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_7]])
 // CHECK:         %[[VAL_36:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
 // CHECK:         call void @__kmpc_barrier(ptr @2, i32 %[[VAL_36]])
@@ -232,7 +236,7 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
 // CHECK:       omp_loop.inc:                                     ; preds = %[[VAL_52]]
 // CHECK:         %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1
 // CHECK:         br label %[[VAL_31]]
-// CHECK:       omp_loop.exit:                                    ; preds = %[[VAL_50]], %[[VAL_35]]
+// CHECK:       omp_loop.exit:
 // CHECK:         call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]])
 // CHECK:         %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
 // CHECK:         call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]])
diff --git a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
index 5e0d3f9f7e293..eeac1344cd7fb 100644
--- a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
@@ -24,16 +24,18 @@ llvm.func @cancellation_point_parallel() {
 // CHECK:         %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0
 // CHECK:         br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]]
 // CHECK:       omp.par.region1.cncl:                             ; preds = %[[VAL_12]]
+// CHECK:         br label %[[FINI:.*]]
+// CHECK:       .fini:
 // CHECK:         %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
 // CHECK:         %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]])
-// CHECK:         br label %[[VAL_20:.*]]
+// CHECK:         br label %[[EXIT_STUB:.*]]
 // CHECK:       omp.par.region1.split:                            ; preds = %[[VAL_12]]
 // CHECK:         br label %[[VAL_21:.*]]
 // CHECK:       omp.region.cont:                                  ; preds = %[[VAL_16]]
 // CHECK:         br label %[[VAL_22:.*]]
 // CHECK:       omp.par.pre_finalize:                             ; preds = %[[VAL_21]]
-// CHECK:         br label %[[VAL_20]]
-// CHECK:       omp.par.exit.exitStub:                            ; preds = %[[VAL_22]], %[[VAL_17]]
+// CHECK:         br label %[[FINI]]
+// CHECK:       omp.par.exit.exitStub:
 // CHECK:         ret void
 
 llvm.func @cancellation_point_sections() {
@@ -94,7 +96,7 @@ llvm.func @cancellation_point_sections() {
 // CHECK:       omp_section_loop.inc:                             ; preds = %[[VAL_46]]
 // CHECK:         %[[VAL_38]] = add nuw i32 %[[VAL_37]], 1
 // CHECK:         br label %[[VAL_35]]
-// CHECK:       omp_section_loop.exit:                            ; preds = %[[VAL_53]], %[[VAL_39]]
+// CHECK:       omp_section_loop.exit:
 // CHECK:         call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_30]])
 // CHECK:         %[[VAL_55:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
 // CHECK:         call void @__kmpc_barrier(ptr @2, i32 %[[VAL_55]])
@@ -175,7 +177,7 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) {
 // CHECK:       omp_loop.inc:                                     ; preds = %[[VAL_106]]
 // CHECK:         %[[VAL_92]] = add nuw i32 %[[VAL_91]], 1
 // CHECK:         br label %[[VAL_89]]
-// CHECK:       omp_loop.exit:                                    ; preds = %[[VAL_105]], %[[VAL_93]]
+// CHECK:       omp_loop.exit:
 // CHECK:         call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_84]])
 // CHECK:         %[[VAL_107:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
 // CHECK:         call void @__kmpc_barrier(ptr @2, i32 %[[VAL_107]])


