[flang-commits] [flang] [mlir] [OpenMP][flang][MLIR] Decouple alloc, init, and copy regions for `omp.private|declare_reduction` ops (PR #125699)

Kareem Ergawy via flang-commits flang-commits at lists.llvm.org
Wed Feb 5 06:14:21 PST 2025


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/125699

>From b7cfed0c0b699067bd2aaf8236e427d611819c2e Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Tue, 4 Feb 2025 08:55:39 -0600
Subject: [PATCH] [OpenMP][flang][MLIR] Decouple alloc, init, and copy regions
 for `omp.private|reduction` ops

This PR changes the emitted block structure of alloc, init, and copy
regions for `omp.private` and `omp.declare_reduction` ops a little bit.
In particular, this decouples init and copy regions from the alloca
insertion-point. The main motivation is to fix "Instruction does not dominate
all uses!" errors that happen especially when an init region uses a value
from the OpenMP region it is being inlined into. The issue happens
because, prior to this PR, we inlined the init region right after the
latest alloc block (since we used the alloca IP), which in some cases
(see example below) is too early and causes the use-dominance issue.

Example that would break without this PR (when delayed privatization is
enabled for `omp.wsloop`s):
```fortran
subroutine test2 (xyz)
  integer :: i
  integer :: xyz(:)

  !$omp target map(from:xyz)
    !$omp do private(xyz)
      do i = 1, 10
        xyz(i) = i
      end do
  !$omp end target
end subroutine
```
---
 .../parallel-private-reduction-worstcase.f90  |  64 +++---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 206 +++++++++++-------
 .../Target/LLVMIR/openmp-firstprivate.mlir    |  11 +-
 mlir/test/Target/LLVMIR/openmp-llvm.mlir      |  17 +-
 mlir/test/Target/LLVMIR/openmp-private.mlir   |   6 +-
 .../LLVMIR/openmp-wsloop-private-cond_br.mlir |   6 +-
 6 files changed, 185 insertions(+), 125 deletions(-)

diff --git a/flang/test/Integration/OpenMP/parallel-private-reduction-worstcase.f90 b/flang/test/Integration/OpenMP/parallel-private-reduction-worstcase.f90
index 6facce56123ab2a..7e735b64995041e 100644
--- a/flang/test/Integration/OpenMP/parallel-private-reduction-worstcase.f90
+++ b/flang/test/Integration/OpenMP/parallel-private-reduction-worstcase.f90
@@ -32,46 +32,52 @@ subroutine worst_case(a, b, c, d)
 ! CHECK-LABEL: define internal void @worst_case_..omp_par
 ! CHECK-NEXT:  omp.par.entry:
 !                [reduction alloc regions inlined here]
-! CHECK:         br label %omp.private.init
+! CHECK:         br label %omp.region.after_alloca
 
-! CHECK:       omp.private.init:                            ; preds = %omp.par.entry
-! CHECK-NEXT:  br label %omp.private.init7
+! CHECK:       omp.region.after_alloca:
+! CHECK-NEXT:    br label %omp.par.region
+
+! CHECK:       omp.par.region:
+! CHECK-NEXT:    br label %omp.private.init
+
+! CHECK:       omp.private.init:
+! CHECK-NEXT:  br label %omp.private.init2
 
-! CHECK:       omp.private.init7:                               ; preds = %omp.private.init
+! CHECK:       omp.private.init2:
 !                [begin private alloc for first var]
 !                [read the length from the mold argument]
 !                [if it is non-zero...]
-! CHECK:         br i1 {{.*}}, label %omp.private.init8, label %omp.private.init9
+! CHECK:         br i1 %{{.*}}, label %omp.private.init3, label %omp.private.init4
 
-! CHECK:       omp.private.init9:                               ; preds = %omp.private.init7
-!                [finish private alloc for first var with zero extent]
-! CHECK:         br label %omp.private.init10
+! CHECK:       omp.private.init4:                               ; preds = %omp.private.init2
+!                [finish private alloc for second var with zero extent]
+! CHECK:         br label %omp.private.init5
 
-! CHECK:       omp.private.init10:                               ; preds = %omp.private.init8, %omp.private.init9
-! CHECK-NEXT:    br label %omp.region.cont6
+! CHECK:       omp.private.init5:                               ; preds = %omp.private.init3, %omp.private.init4
+! CHECK-NEXT:    br label %omp.region.cont
 
-! CHECK:       omp.region.cont6:                                 ; preds = %omp.private.init10
+! CHECK:       omp.region.cont:                                  ; preds = %omp.private.init5
 ! CHECK-NEXT:    %{{.*}} = phi ptr
-! CHECK-NEXT:    br label %omp.private.init1
+! CHECK-NEXT:    br label %omp.private.init7
 
-! CHECK:       omp.private.init1:                                ; preds = %omp.region.cont6
+! CHECK:       omp.private.init7:
 !                [begin private alloc for first var]
 !                [read the length from the mold argument]
 !                [if it is non-zero...]
-! CHECK:         br i1 %{{.*}}, label %omp.private.init2, label %omp.private.init3
+! CHECK:         br i1 {{.*}}, label %omp.private.init8, label %omp.private.init9
 
-! CHECK:       omp.private.init3:                               ; preds = %omp.private.init1
-!                [finish private alloc for second var with zero extent]
-! CHECK:         br label %omp.private.init4
+! CHECK:       omp.private.init9:                               ; preds = %omp.private.init7
+!                [finish private alloc for first var with zero extent]
+! CHECK:         br label %omp.private.init10
 
-! CHECK:       omp.private.init4:                               ; preds = %omp.private.init2, %omp.private.init3
-! CHECK-NEXT:    br label %omp.region.cont
+! CHECK:       omp.private.init10:                               ; preds = %omp.private.init8, %omp.private.init9
+! CHECK-NEXT:    br label %omp.region.cont6
 
-! CHECK:       omp.region.cont:                                  ; preds = %omp.private.init4
+! CHECK:       omp.region.cont6:                                 ; preds = %omp.private.init10
 ! CHECK-NEXT:    %{{.*}} = phi ptr
 ! CHECK-NEXT:    br label %omp.private.copy
 
-! CHECK:       omp.private.copy:                                 ; preds = %omp.region.cont
+! CHECK:       omp.private.copy:
 ! CHECK-NEXT:    br label %omp.private.copy12
 
 ! CHECK:       omp.private.copy12:                               ; preds = %omp.private.copy
@@ -96,15 +102,9 @@ subroutine worst_case(a, b, c, d)
 
 ! CHECK:       omp.region.cont15:                                ; preds = %omp.private.copy18
 ! CHECK-NEXT:    %{{.*}} = phi ptr
-! CHECK-NEXT:    br label %omp.region.after_alloca
-
-! CHECK:       omp.region.after_alloca:
-! CHECK-NEXT:    br label %omp.par.region
-
-! CHECK:       omp.par.region:                                   ; preds = %omp.region.after_alloca
 ! CHECK-NEXT:    br label %omp.reduction.init
 
-! CHECK:       omp.reduction.init:                               ; preds = %omp.par.region
+! CHECK:       omp.reduction.init:                               ; preds = %omp.region.cont15
 !                [deffered stores for results of reduction alloc regions]
 ! CHECK:         br label %[[VAL_96:.*]]
 
@@ -211,13 +211,13 @@ subroutine worst_case(a, b, c, d)
 !                [source length was non-zero: call assign runtime]
 ! CHECK:         br label %omp.private.copy14
 
-! CHECK:       omp.private.init2:                               ; preds = %omp.private.init1
-!                [var extent was non-zero: malloc a private array]
-! CHECK:         br label %omp.private.init4
-
 ! CHECK:       omp.private.init8:                               ; preds = %omp.private.init7
 !                [var extent was non-zero: malloc a private array]
 ! CHECK:         br label %omp.private.init10
 
+! CHECK:       omp.private.init3:                               ; preds = %omp.private.init2
+!                [var extent was non-zero: malloc a private array]
+! CHECK:         br label %omp.private.init5
+
 ! CHECK:       omp.par.outlined.exit.exitStub:                   ; preds = %omp.region.cont52
 ! CHECK-NEXT:    ret void
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index ea044fe0c8c196c..2e1f29cbf35e50a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1025,6 +1025,18 @@ mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
   }
 }
 
+static void
+setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
+                                    llvm::BasicBlock *block = nullptr) {
+  if (block == nullptr)
+    block = builder.GetInsertBlock();
+
+  if (block->empty() || block->getTerminator() == nullptr)
+    builder.SetInsertPoint(block);
+  else
+    builder.SetInsertPoint(block->getTerminator());
+}
+
 /// Inline reductions' `init` regions. This functions assumes that the
 /// `builder`'s insertion point is where the user wants the `init` regions to be
 /// inlined; i.e. it does not try to find a proper insertion location for the
@@ -1063,10 +1075,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
     }
   }
 
-  if (initBlock->empty() || initBlock->getTerminator() == nullptr)
-    builder.SetInsertPoint(initBlock);
-  else
-    builder.SetInsertPoint(initBlock->getTerminator());
+  setInsertPointForPossiblyEmptyBlock(builder, initBlock);
 
   // store result of the alloc region to the allocated pointer to the real
   // reduction variable
@@ -1091,11 +1100,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
     assert(phis.size() == 1 && "expected one value to be yielded from the "
                                "reduction neutral element declaration region");
 
-    if (builder.GetInsertBlock()->empty() ||
-        builder.GetInsertBlock()->getTerminator() == nullptr)
-      builder.SetInsertPoint(builder.GetInsertBlock());
-    else
-      builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
+    setInsertPointForPossiblyEmptyBlock(builder);
 
     if (isByRef[i]) {
       if (!reductionDecls[i].getAllocRegion().empty())
@@ -1331,18 +1336,15 @@ findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
 
 /// Initialize a single (first)private variable. You probably want to use
 /// allocateAndInitPrivateVars instead of this.
-static llvm::Error
-initPrivateVar(llvm::IRBuilderBase &builder,
-               LLVM::ModuleTranslation &moduleTranslation,
-               omp::PrivateClauseOp &privDecl, Value mlirPrivVar,
-               BlockArgument &blockArg, llvm::Value *llvmPrivateVar,
-               llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
-               llvm::BasicBlock *privInitBlock,
-               llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
+static llvm::Error initPrivateVar(
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
+    omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
+    llvm::SmallVectorImpl<llvm::Value *>::iterator llvmPrivateVarIt,
+    llvm::BasicBlock *privInitBlock,
+    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
   Region &initRegion = privDecl.getInitRegion();
   if (initRegion.empty()) {
-    moduleTranslation.mapValue(blockArg, llvmPrivateVar);
-    llvmPrivateVars.push_back(llvmPrivateVar);
+    moduleTranslation.mapValue(blockArg, *llvmPrivateVarIt);
     return llvm::Error::success();
   }
 
@@ -1351,11 +1353,10 @@ initPrivateVar(llvm::IRBuilderBase &builder,
       mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
   assert(nonPrivateVar);
   moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
-  moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
+  moduleTranslation.mapValue(privDecl.getInitPrivateArg(), *llvmPrivateVarIt);
 
   // in-place convert the private initialization region
   SmallVector<llvm::Value *, 1> phis;
-  builder.SetInsertPoint(privInitBlock->getTerminator());
   if (failed(inlineConvertOmpRegions(initRegion, "omp.private.init", builder,
                                      moduleTranslation, &phis)))
     return llvm::createStringError(
@@ -1367,7 +1368,7 @@ initPrivateVar(llvm::IRBuilderBase &builder,
   // variable in case the region is operating on arguments by-value (e.g.
   // Fortran character boxes).
   moduleTranslation.mapValue(blockArg, phis[0]);
-  llvmPrivateVars.push_back(phis[0]);
+  *llvmPrivateVarIt = phis[0];
 
   // clear init region block argument mapping in case it needs to be
   // re-created with a different source for another use of the same
@@ -1376,17 +1377,48 @@ initPrivateVar(llvm::IRBuilderBase &builder,
   return llvm::Error::success();
 }
 
+static llvm::Error
+initPrivateVars(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation,
+                MutableArrayRef<BlockArgument> privateBlockArgs,
+                MutableArrayRef<omp::PrivateClauseOp> privateDecls,
+                MutableArrayRef<mlir::Value> mlirPrivateVars,
+                llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
+                llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
+  if (privateBlockArgs.empty())
+    return llvm::Error::success();
+
+  llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
+  setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
+
+  for (auto [idx, zip] : llvm::enumerate(
+           llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs))) {
+    auto [privDecl, mlirPrivVar, blockArg] = zip;
+    llvm::Error err = initPrivateVar(
+        builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
+        llvmPrivateVars.begin() + idx, privInitBlock, mappedPrivateVars);
+
+    if (err)
+      return err;
+
+    setInsertPointForPossiblyEmptyBlock(builder);
+  }
+
+  return llvm::Error::success();
+}
+
 /// Allocate and initialize delayed private variables. Returns the basic block
 /// which comes after all of these allocations. llvm::Value * for each of these
 /// private variables are populated in llvmPrivateVars.
-static llvm::Expected<llvm::BasicBlock *> allocateAndInitPrivateVars(
-    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
-    MutableArrayRef<BlockArgument> privateBlockArgs,
-    MutableArrayRef<omp::PrivateClauseOp> privateDecls,
-    MutableArrayRef<mlir::Value> mlirPrivateVars,
-    llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
-    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
-    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
+static llvm::Expected<llvm::BasicBlock *>
+allocatePrivateVars(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    MutableArrayRef<BlockArgument> privateBlockArgs,
+                    MutableArrayRef<omp::PrivateClauseOp> privateDecls,
+                    MutableArrayRef<mlir::Value> mlirPrivateVars,
+                    llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
+                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+                    llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
   // Allocate private vars
   llvm::BranchInst *allocaTerminator =
       llvm::cast<llvm::BranchInst>(allocaIP.getBlock()->getTerminator());
@@ -1404,9 +1436,6 @@ static llvm::Expected<llvm::BasicBlock *> allocateAndInitPrivateVars(
 
   llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
 
-  llvm::BasicBlock *privInitBlock = nullptr;
-  if (!privateBlockArgs.empty())
-    privInitBlock = splitBB(builder, true, "omp.private.init");
   for (auto [privDecl, mlirPrivVar, blockArg] :
        llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
     llvm::Type *llvmAllocType =
@@ -1414,24 +1443,18 @@ static llvm::Expected<llvm::BasicBlock *> allocateAndInitPrivateVars(
     builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
     llvm::Value *llvmPrivateVar = builder.CreateAlloca(
         llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
-
-    llvm::Error err = initPrivateVar(
-        builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
-        llvmPrivateVar, llvmPrivateVars, privInitBlock, mappedPrivateVars);
-    if (err)
-      return err;
+    llvmPrivateVars.push_back(llvmPrivateVar);
   }
+
   return afterAllocas;
 }
 
 static LogicalResult
-initFirstPrivateVars(llvm::IRBuilderBase &builder,
+copyFirstPrivateVars(llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
                      SmallVectorImpl<mlir::Value> &mlirPrivateVars,
                      SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
-                     SmallVectorImpl<omp::PrivateClauseOp> &privateDecls,
-                     llvm::BasicBlock *afterAllocas) {
-  llvm::IRBuilderBase::InsertPointGuard guard(builder);
+                     SmallVectorImpl<omp::PrivateClauseOp> &privateDecls) {
   // Apply copy region for firstprivate.
   bool needsFirstprivate =
       llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
@@ -1442,13 +1465,9 @@ initFirstPrivateVars(llvm::IRBuilderBase &builder,
   if (!needsFirstprivate)
     return success();
 
-  assert(afterAllocas->getSinglePredecessor());
-
-  // Find the end of the allocation blocks
-  builder.SetInsertPoint(afterAllocas->getSinglePredecessor()->getTerminator());
   llvm::BasicBlock *copyBlock =
       splitBB(builder, /*CreateBranch=*/true, "omp.private.copy");
-  builder.SetInsertPoint(copyBlock->getFirstNonPHIOrDbgOrAlloca());
+  setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
 
   for (auto [decl, mlirVar, llvmVar] :
        llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
@@ -1467,11 +1486,12 @@ initFirstPrivateVars(llvm::IRBuilderBase &builder,
     moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
 
     // in-place convert copy region
-    builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
     if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
                                        moduleTranslation)))
       return decl.emitError("failed to inline `copy` region of `omp.private`");
 
+    setInsertPointForPossiblyEmptyBlock(builder);
+
     // ignore unused value yielded from copy region
 
     // clear copy region block argument mapping in case it needs to be
@@ -1757,20 +1777,25 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
     LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
         moduleTranslation, allocaIP);
 
-    llvm::Expected<llvm::BasicBlock *> afterAllocas =
-        allocateAndInitPrivateVars(builder, moduleTranslation, privateBlockArgs,
-                                   privateDecls, mlirPrivateVars,
-                                   llvmPrivateVars, allocaIP);
+    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
+        builder, moduleTranslation, privateBlockArgs, privateDecls,
+        mlirPrivateVars, llvmPrivateVars, allocaIP);
     if (handleError(afterAllocas, *taskOp).failed())
       return llvm::make_error<PreviouslyReportedError>();
 
-    if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
-                                    llvmPrivateVars, privateDecls,
-                                    afterAllocas.get())))
+    builder.restoreIP(codegenIP);
+    if (handleError(initPrivateVars(builder, moduleTranslation,
+                                    privateBlockArgs, privateDecls,
+                                    mlirPrivateVars, llvmPrivateVars),
+                    *taskOp)
+            .failed())
+      return llvm::make_error<PreviouslyReportedError>();
+
+    if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
+                                    llvmPrivateVars, privateDecls)))
       return llvm::make_error<PreviouslyReportedError>();
 
     // translate the body of the task:
-    builder.restoreIP(codegenIP);
     auto continuationBlockOrError = convertOmpOpRegions(
         taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
     if (failed(handleError(continuationBlockOrError, *taskOp)))
@@ -1892,7 +1917,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   SmallVector<llvm::Value *> privateReductionVariables(
       wsloopOp.getNumReductionVars());
 
-  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocateAndInitPrivateVars(
+  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
       builder, moduleTranslation, privateBlockArgs, privateDecls,
       mlirPrivateVars, llvmPrivateVars, allocaIP);
   if (handleError(afterAllocas, opInst).failed())
@@ -1911,9 +1936,15 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                                 deferredStores, isByRef)))
     return failure();
 
-  if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
-                                  llvmPrivateVars, privateDecls,
-                                  afterAllocas.get())))
+  if (handleError(initPrivateVars(builder, moduleTranslation, privateBlockArgs,
+                                  privateDecls, mlirPrivateVars,
+                                  llvmPrivateVars),
+                  opInst)
+          .failed())
+    return failure();
+
+  if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
+                                  llvmPrivateVars, privateDecls)))
     return failure();
 
   assert(afterAllocas.get()->getSinglePredecessor());
@@ -2064,10 +2095,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
 
   auto bodyGenCB = [&](InsertPointTy allocaIP,
                        InsertPointTy codeGenIP) -> llvm::Error {
-    llvm::Expected<llvm::BasicBlock *> afterAllocas =
-        allocateAndInitPrivateVars(builder, moduleTranslation, privateBlockArgs,
-                                   privateDecls, mlirPrivateVars,
-                                   llvmPrivateVars, allocaIP);
+    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
+        builder, moduleTranslation, privateBlockArgs, privateDecls,
+        mlirPrivateVars, llvmPrivateVars, allocaIP);
     if (handleError(afterAllocas, *opInst).failed())
       return llvm::make_error<PreviouslyReportedError>();
 
@@ -2087,14 +2117,20 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
             deferredStores, isByRef)))
       return llvm::make_error<PreviouslyReportedError>();
 
-    if (failed(initFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
-                                    llvmPrivateVars, privateDecls,
-                                    afterAllocas.get())))
-      return llvm::make_error<PreviouslyReportedError>();
-
     assert(afterAllocas.get()->getSinglePredecessor());
     builder.restoreIP(codeGenIP);
 
+    if (handleError(initPrivateVars(builder, moduleTranslation,
+                                    privateBlockArgs, privateDecls,
+                                    mlirPrivateVars, llvmPrivateVars),
+                    *opInst)
+            .failed())
+      return llvm::make_error<PreviouslyReportedError>();
+
+    if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
+                                    llvmPrivateVars, privateDecls)))
+      return llvm::make_error<PreviouslyReportedError>();
+
     if (failed(
             initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
                               afterAllocas.get()->getSinglePredecessor(),
@@ -2142,6 +2178,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
       tempTerminator->eraseFromParent();
       builder.restoreIP(*contInsertPoint);
     }
+
     return llvm::Error::success();
   };
 
@@ -2243,14 +2280,22 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
-  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 
-  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocateAndInitPrivateVars(
+  llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
       builder, moduleTranslation, privateBlockArgs, privateDecls,
       mlirPrivateVars, llvmPrivateVars, allocaIP);
   if (handleError(afterAllocas, opInst).failed())
     return failure();
 
+  if (handleError(initPrivateVars(builder, moduleTranslation, privateBlockArgs,
+                                  privateDecls, mlirPrivateVars,
+                                  llvmPrivateVars),
+                  opInst)
+          .failed())
+    return failure();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
   // Generator of the canonical loop body.
   SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
   SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
@@ -4265,21 +4310,28 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
     for (mlir::Value privateVar : targetOp.getPrivateVars())
       mlirPrivateVars.push_back(privateVar);
 
-    llvm::Expected<llvm::BasicBlock *> afterAllocas =
-        allocateAndInitPrivateVars(
-            builder, moduleTranslation, privateBlockArgs, privateDecls,
-            mlirPrivateVars, llvmPrivateVars, allocaIP, &mappedPrivateVars);
+    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
+        builder, moduleTranslation, privateBlockArgs, privateDecls,
+        mlirPrivateVars, llvmPrivateVars, allocaIP, &mappedPrivateVars);
 
     if (failed(handleError(afterAllocas, *targetOp)))
       return llvm::make_error<PreviouslyReportedError>();
 
+    builder.restoreIP(codeGenIP);
+    if (handleError(initPrivateVars(builder, moduleTranslation,
+                                    privateBlockArgs, privateDecls,
+                                    mlirPrivateVars, llvmPrivateVars,
+                                    &mappedPrivateVars),
+                    *targetOp)
+            .failed())
+      return llvm::make_error<PreviouslyReportedError>();
+
     SmallVector<Region *> privateCleanupRegions;
     llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
                     [](omp::PrivateClauseOp privatizer) {
                       return &privatizer.getDeallocRegion();
                     });
 
-    builder.restoreIP(codeGenIP);
     llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
         targetRegion, "omp.target", builder, moduleTranslation);
 
diff --git a/mlir/test/Target/LLVMIR/openmp-firstprivate.mlir b/mlir/test/Target/LLVMIR/openmp-firstprivate.mlir
index 4e27640b478e43c..0358f713a324602 100644
--- a/mlir/test/Target/LLVMIR/openmp-firstprivate.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-firstprivate.mlir
@@ -59,6 +59,13 @@ llvm.func @parallel_op_firstprivate_multi_block(%arg0: !llvm.ptr) {
 // CHECK:  %[[ORIG_PTR_PTR:.*]] = getelementptr { ptr }, ptr %{{.*}}, i32 0, i32 0
 // CHECK:  %[[ORIG_PTR:.*]] = load ptr, ptr %[[ORIG_PTR_PTR]], align 8
 // CHECK:  %[[PRIV_ALLOC:.*]] = alloca float, align 4
+// CHECK-NEXT: br label %omp.region.after_alloca
+
+// CHECK: omp.region.after_alloca:
+// CHECK-NEXT:   br label %[[PAR_REG:.*]]
+
+// Check that the body of the parallel region loads from the private clone.
+// CHECK: [[PAR_REG]]:
 // CHECK:   br label %[[PRIV_BB1:.*]]
 
 // CHECK: [[PRIV_BB1]]:
@@ -95,10 +102,6 @@ llvm.func @parallel_op_firstprivate_multi_block(%arg0: !llvm.ptr) {
 // address.
 // CHECK: [[PRIV_CONT]]:
 // CHECK-NEXT:   %[[PRIV_ALLOC4:.*]] = phi ptr [ %[[PRIV_ALLOC3]], %[[PRIV_BB3]] ]
-// CHECK-NEXT:   br label %[[PAR_REG:.*]]
-
-// Check that the body of the parallel region loads from the private clone.
-// CHECK: [[PAR_REG]]:
 // CHECK:        %{{.*}} = load float, ptr %[[PRIV_ALLOC2]], align 4
 
 omp.private {type = firstprivate} @multi_block.privatizer : f32 init {
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 8a95793b96fd531..5f42240cf978ecd 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2825,18 +2825,23 @@ llvm.func @task(%arg0 : !llvm.ptr) {
 // CHECK:         %[[VAL_13:.*]] = getelementptr { ptr }, ptr %[[VAL_11]], i32 0, i32 0
 // CHECK:         %[[VAL_14:.*]] = load ptr, ptr %[[VAL_13]], align 8
 // CHECK:         %[[VAL_15:.*]] = alloca i32, align 4
+// CHECK:         br label %omp.region.after_alloca
+
+// CHECK:       omp.region.after_alloca:
+// CHECK:         br label %task.body
+
+// CHECK:       task.body:                                        ; preds = %omp.region.after_alloca
 // CHECK:         br label %omp.private.init
-// CHECK:       omp.private.init:                                 ; preds = %task.alloca
+
+// CHECK:       omp.private.init:                                 ; preds = %task.body
 // CHECK:         br label %omp.private.copy
+
 // CHECK:       omp.private.copy:                                 ; preds = %omp.private.init
 // CHECK:         %[[VAL_19:.*]] = load i32, ptr %[[VAL_14]], align 4
 // CHECK:         store i32 %[[VAL_19]], ptr %[[VAL_15]], align 4
-// CHECK:         br label %[[VAL_20:.*]]
-// CHECK:       [[VAL_20]]:
-// CHECK:         br label %task.body
-// CHECK:       task.body:                                        ; preds = %[[VAL_20]]
 // CHECK:         br label %omp.task.region
-// CHECK:       omp.task.region:                                  ; preds = %task.body
+
+// CHECK:       omp.task.region:                                  ; preds = %omp.private.copy
 // CHECK:         call void @foo(ptr %[[VAL_15]])
 // CHECK:         br label %omp.region.cont
 // CHECK:       omp.region.cont:                                  ; preds = %omp.task.region
diff --git a/mlir/test/Target/LLVMIR/openmp-private.mlir b/mlir/test/Target/LLVMIR/openmp-private.mlir
index 6913dd5427844d0..704935fa12e9736 100644
--- a/mlir/test/Target/LLVMIR/openmp-private.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-private.mlir
@@ -103,6 +103,9 @@ llvm.func @parallel_op_private_multi_block(%arg0: !llvm.ptr) {
 // CHECK:  %[[ORIG_PTR_PTR:.*]] = getelementptr { ptr }, ptr %{{.*}}, i32 0, i32 0
 // CHECK:  %[[ORIG_PTR:.*]] = load ptr, ptr %[[ORIG_PTR_PTR]], align 8
 // CHECK:  %[[PRIV_ALLOC:.*]] = alloca float, align 4
+// CHECK-NEXT:   br label %[[PAR_REG:.*]]
+
+// CHECK: [[PAR_REG]]:
 // CHECK:  br label %omp.private.init
 
 // CHECK: omp.private.init:
@@ -124,9 +127,6 @@ llvm.func @parallel_op_private_multi_block(%arg0: !llvm.ptr) {
 // address.
 // CHECK: [[PRIV_CONT]]:
 // CHECK-NEXT:   %[[PRIV_ALLOC3:.*]] = phi ptr [ %[[PRIV_ALLOC2]], %[[PRIV_BB2]] ]
-// CHECK-NEXT:   br label %[[PAR_REG:.*]]
-
-// CHECK: [[PAR_REG]]:
 // CHECK-NEXT:   br label %[[PAR_REG2:.*]]
 
 // Check that the body of the parallel region loads from the private clone.
diff --git a/mlir/test/Target/LLVMIR/openmp-wsloop-private-cond_br.mlir b/mlir/test/Target/LLVMIR/openmp-wsloop-private-cond_br.mlir
index 33737c4368a18be..ea7706365b76107 100644
--- a/mlir/test/Target/LLVMIR/openmp-wsloop-private-cond_br.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-wsloop-private-cond_br.mlir
@@ -30,9 +30,6 @@ llvm.func @wsloop_private_(%arg0: !llvm.ptr {fir.bindc_name = "y"}) attributes {
 }
 
 // CHECK:   %[[INT:.*]] = alloca i32, i64 1, align 4
-// CHECK:   br label %[[OMP_PRIVATE_INIT:.*]]
-
-// CHECK: [[OMP_PRIVATE_INIT]]:
 // CHECK:   br label %[[AFTER_ALLOC_BB:.*]]
 
 // CHECK: [[AFTER_ALLOC_BB]]:
@@ -45,5 +42,8 @@ llvm.func @wsloop_private_(%arg0: !llvm.ptr {fir.bindc_name = "y"}) attributes {
 // CHECK:   br label %[[BB3:.*]]
 
 // CHECK: [[BB3]]:
+// CHECK:   br label %[[OMP_PRIVATE_INIT:.*]]
+
+// CHECK: [[OMP_PRIVATE_INIT]]:
 // CHECK:   br label %omp_loop.preheader
 



More information about the flang-commits mailing list