[llvm-branch-commits] [flang] [llvm] [mlir] [MLIR][OpenMP][OMPIRBuilder] Improve shared memory checks (PR #161864)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Apr 27 04:47:06 PDT 2026


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/161864

>From 17fc2c1e6db758f83f8160de6fe7b6475ad7658a Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 16 Sep 2025 14:18:39 +0100
Subject: [PATCH] [MLIR][OpenMP][OMPIRBuilder] Improve shared memory checks

This patch refines checks to decide whether to use device shared memory or
regular stack allocations. In particular, it adds support for parallel regions
residing in standalone target device functions.

The changes are:
- Shared memory is introduced for `omp.target` implicit allocations, such as
those related to privatization and mapping, as long as they are shared across
threads in a nested parallel region.
- Standalone target device functions are interpreted as being part of a Generic
kernel, since the fact that they are present in the module after filtering
means they must be reachable from a target region.
- Prevent allocations from being moved to device shared memory when their only
shared uses inside of an `omp.parallel` region are as part of a `private` clause
whose privatizer has no copy region.
---
 .../OpenMP/target-use-device-nested.f90       |  25 ++--
 .../OpenMP/threadprivate-target-device.f90    |  14 +-
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |   2 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |   5 +-
 .../Frontend/OpenMPIRBuilderTest.cpp          |  38 ++---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 133 +++++++++++++-----
 .../omptarget-constant-alloca-raise.mlir      |   2 +-
 .../LLVMIR/omptarget-parallel-llvm.mlir       |   8 +-
 .../openmp-target-private-shared-mem.mlir     |  76 ++++++++++
 .../fortran/target-generic-outlined-loops.f90 | 109 ++++++++++++++
 10 files changed, 331 insertions(+), 81 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-target-private-shared-mem.mlir
 create mode 100644 offload/test/offloading/fortran/target-generic-outlined-loops.f90

diff --git a/flang/test/Integration/OpenMP/target-use-device-nested.f90 b/flang/test/Integration/OpenMP/target-use-device-nested.f90
index 9bb4c39842731..97644383f00ed 100644
--- a/flang/test/Integration/OpenMP/target-use-device-nested.f90
+++ b/flang/test/Integration/OpenMP/target-use-device-nested.f90
@@ -7,7 +7,7 @@
 !===----------------------------------------------------------------------===!
 
 ! This tests check that target code nested inside a target data region which
-! has only use_device_ptr mapping corectly generates code on the device pass.
+! has only use_device_ptr mapping correctly generates code on the device pass.
 
 !REQUIRES: amdgpu-registered-target
 !RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-version=50 -fopenmp-is-target-device %s -o - | FileCheck %s
@@ -25,22 +25,21 @@ program main
 
 ! CHECK:         define weak_odr protected amdgpu_kernel void @__omp_offloading{{.*}}main_
 ! CHECK-NEXT:       entry:
-! CHECK-NEXT:         %[[VAL_3:.*]] = alloca ptr, align 8, addrspace(5)
-! CHECK-NEXT:         %[[ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[VAL_3]] to ptr
-! CHECK-NEXT:         store ptr %[[VAL_4:.*]], ptr %[[ASCAST]], align 8
-! CHECK-NEXT:         %[[VAL_5:.*]] = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) @__omp_offloading_{{.*}}_kernel_environment to ptr), ptr %[[VAL_6:.*]])
-! CHECK-NEXT:         %[[VAL_7:.*]] = icmp eq i32 %[[VAL_5]], -1
-! CHECK-NEXT:         br i1 %[[VAL_7]], label %[[VAL_8:.*]], label %[[VAL_9:.*]]
-! CHECK:            user_code.entry:                                  ; preds = %[[VAL_10:.*]]
-! CHECK-NEXT:         %[[VAL_11:.*]] = load ptr, ptr %[[ASCAST]], align 8
+! CHECK-NEXT:         %[[VAL_0:.*]] = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) @__omp_offloading_{{.*}}_kernel_environment to ptr), ptr %[[VAL_6:.*]])
+! CHECK-NEXT:         %[[VAL_1:.*]] = icmp eq i32 %[[VAL_0]], -1
+! CHECK-NEXT:         br i1 %[[VAL_1]], label %[[USER_ENTRY:.*]], label %[[EXIT:.*]]
+! CHECK:            [[USER_ENTRY]]:                                  ; preds = %entry
+! CHECK-NEXT:         %[[VAL_2:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8) 
+! CHECK-NEXT:         store ptr %[[VAL_3:.*]], ptr %[[VAL_2]], align 8
+! CHECK-NEXT:         %[[VAL_4:.*]] = load ptr, ptr %[[VAL_2]], align 8
 ! CHECK-NEXT:         br label %[[AFTER_ALLOC:.*]]
 
 ! CHECK:            [[AFTER_ALLOC]]:
-! CHECK-NEXT:         br label %[[VAL_12:.*]]
+! CHECK-NEXT:         br label %[[VAL_5:.*]]
 
-! CHECK:            [[VAL_12]]:
+! CHECK:            [[VAL_5]]:
 ! CHECK-NEXT:         br label %[[TARGET_REG_ENTRY:.*]]
 
-! CHECK:            [[TARGET_REG_ENTRY]]:                                       ; preds = %[[VAL_12]]
-! CHECK-NEXT:         call void @{{.*}}foo{{.*}}(ptr %[[VAL_11]])
+! CHECK:            [[TARGET_REG_ENTRY]]:                                       ; preds = %[[VAL_5]]
+! CHECK-NEXT:         call void @{{.*}}foo{{.*}}(ptr %[[VAL_4]])
 ! CHECK-NEXT:         br label
diff --git a/flang/test/Integration/OpenMP/threadprivate-target-device.f90 b/flang/test/Integration/OpenMP/threadprivate-target-device.f90
index 54fb332a78bb0..563cac697aa25 100644
--- a/flang/test/Integration/OpenMP/threadprivate-target-device.f90
+++ b/flang/test/Integration/OpenMP/threadprivate-target-device.f90
@@ -14,16 +14,14 @@
 ! target code in the same function.
 
 ! CHECK: define weak_odr protected amdgpu_kernel void @{{.*}}(ptr %[[ARG1:.*]], ptr %[[ARG2:.*]], ptr %{{.*}}) #{{[0-9]+}} {
-! CHECK:  %[[ALLOCA_X:.*]] = alloca ptr, align 8, addrspace(5)
-! CHECK:  %[[ASCAST_X:.*]] = addrspacecast ptr addrspace(5) %[[ALLOCA_X]] to ptr
-! CHECK:  store ptr %[[ARG1]], ptr %[[ASCAST_X]], align 8
+! CHECK:  %[[ALLOC_N:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
+! CHECK:  store ptr %[[ARG2]], ptr %[[ALLOC_N]], align 8
 
-! CHECK:  %[[ALLOCA_N:.*]] = alloca ptr, align 8, addrspace(5)
-! CHECK:  %[[ASCAST_N:.*]] = addrspacecast ptr addrspace(5) %[[ALLOCA_N]] to ptr
-! CHECK:  store ptr %[[ARG2]], ptr %[[ASCAST_N]], align 8
+! CHECK:  %[[ALLOC_X:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
+! CHECK:  store ptr %[[ARG1]], ptr %[[ALLOC_X]], align 8
 
-! CHECK:  %[[LOAD_X:.*]] = load ptr, ptr %[[ASCAST_X]], align 8
-! CHECK:  call void @bar_(ptr %[[LOAD_X]], ptr %[[ASCAST_N]])
+! CHECK:  %[[LOAD_X:.*]] = load ptr, ptr %[[ALLOC_X]], align 8
+! CHECK:  call void @bar_(ptr %[[LOAD_X]], ptr %[[ALLOC_N]])
 
 module test
   implicit none
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 00b951505268b..ba78783c7f13e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3609,7 +3609,7 @@ class OpenMPIRBuilder {
 
   using TargetGenArgAccessorsCallbackTy = function_ref<InsertPointOrErrorTy(
       Argument &Arg, Value *Input, Value *&RetVal, InsertPointTy AllocaIP,
-      InsertPointTy CodeGenIP)>;
+      InsertPointTy CodeGenIP, ArrayRef<InsertPointTy> DeallocIPs)>;
 
   /// Generator for '#omp target'
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ad07fcfc1957f..3cf4f09f02e29 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8903,8 +8903,9 @@ static Expected<Function *> createOutlinedFunction(
     Argument &Arg = std::get<1>(InArg);
     Value *InputCopy = nullptr;
 
-    llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP());
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = ArgAccessorFuncCB(
+        Arg, Input, InputCopy, AllocaIP, Builder.saveIP(),
+        OpenMPIRBuilder::InsertPointTy(ExitBB, ExitBB->begin()));
     if (!AfterIP)
       return AfterIP.takeError();
     Builder.restoreIP(*AfterIP);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index eb71d0949c854..964f40469f4a9 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6452,7 +6452,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   auto SimpleArgAccessorCB =
       [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
           llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
-          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
         IRBuilderBase::InsertPointGuard guard(Builder);
         Builder.SetCurrentDebugLocation(llvm::DebugLoc());
         if (!OMPBuilder.Config.isTargetDevice()) {
@@ -6618,7 +6619,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   auto SimpleArgAccessorCB =
       [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
           llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
-          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
         IRBuilderBase::InsertPointGuard guard(Builder);
         Builder.SetCurrentDebugLocation(llvm::DebugLoc());
         if (!OMPBuilder.Config.isTargetDevice()) {
@@ -6820,12 +6822,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
     return Builder.saveIP();
   };
 
-  auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
-                                 OpenMPIRBuilder::InsertPointTy,
-                                 OpenMPIRBuilder::InsertPointTy CodeGenIP) {
-    Builder.restoreIP(CodeGenIP);
-    return Builder.saveIP();
-  };
+  auto SimpleArgAccessorCB =
+      [&](Argument &, Value *, Value *&, OpenMPIRBuilder::InsertPointTy,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy>) {
+        Builder.restoreIP(CodeGenIP);
+        return Builder.saveIP();
+      };
 
   SmallVector<Value *> Inputs;
   OpenMPIRBuilder::MapInfosTy CombinedInfos;
@@ -6920,12 +6923,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
   Function *OutlinedFn = nullptr;
   SmallVector<Value *> CapturedArgs;
 
-  auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
-                                 OpenMPIRBuilder::InsertPointTy,
-                                 OpenMPIRBuilder::InsertPointTy CodeGenIP) {
-    Builder.restoreIP(CodeGenIP);
-    return Builder.saveIP();
-  };
+  auto SimpleArgAccessorCB =
+      [&](Argument &, Value *, Value *&, OpenMPIRBuilder::InsertPointTy,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy>) {
+        Builder.restoreIP(CodeGenIP);
+        return Builder.saveIP();
+      };
 
   OpenMPIRBuilder::MapInfosTy CombinedInfos;
   auto GenMapInfoCB =
@@ -7019,7 +7023,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
   auto SimpleArgAccessorCB =
       [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
           llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
-          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
         IRBuilderBase::InsertPointGuard guard(Builder);
         Builder.SetCurrentDebugLocation(llvm::DebugLoc());
         if (!OMPBuilder.Config.isTargetDevice()) {
@@ -7202,7 +7207,8 @@ TEST_F(OpenMPIRBuilderTest, DebugRecordLoc) {
   auto SimpleArgAccessorCB =
       [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
           llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
-          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
+          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
         IRBuilderBase::InsertPointGuard guard(Builder);
         Builder.SetCurrentDebugLocation(llvm::DebugLoc());
         if (!OMPBuilder.Config.isTargetDevice()) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e0cccbde6b442..4f82eb4706d52 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1138,9 +1138,10 @@ struct DeferredStore {
 } // namespace
 
 /// Check whether allocations for the given operation might potentially have to
-/// be done in device shared memory. That means we're compiling for a offloading
-/// target, the operation is an `omp::TargetOp` or nested inside of one and that
-/// target region represents a Generic (non-SPMD) kernel.
+/// be done in device shared memory. That means we're compiling for an
+/// offloading target and either the operation is not part of an `omp::TargetOp`
+/// (e.g. it is in a standalone device function) or the target region it
+/// belongs to represents a Generic (non-SPMD) kernel.
 ///
 /// This represents a necessary but not sufficient set of conditions to use
 /// device shared memory in place of regular allocas. For some variables, the
@@ -1156,7 +1157,7 @@ mightAllocInDeviceSharedMemory(Operation &op,
   if (!targetOp)
     targetOp = op.getParentOfType<omp::TargetOp>();
 
-  return targetOp &&
+  return !targetOp ||
          targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) ==
              omp::TargetExecMode::generic;
 }
@@ -1170,19 +1171,42 @@ mightAllocInDeviceSharedMemory(Operation &op,
 /// operation that owns the specified block argument.
 static bool mustAllocPrivateVarInDeviceSharedMemory(BlockArgument value) {
   Operation *parentOp = value.getOwner()->getParentOp();
-  auto targetOp = dyn_cast<omp::TargetOp>(parentOp);
-  if (!targetOp)
-    targetOp = parentOp->getParentOfType<omp::TargetOp>();
-  assert(targetOp && "expected a parent omp.target operation");
-
+  auto moduleOp = parentOp->getParentOfType<ModuleOp>();
   for (auto *user : value.getUsers()) {
     if (auto parallelOp = dyn_cast<omp::ParallelOp>(user)) {
       if (llvm::is_contained(parallelOp.getReductionVars(), value))
         return true;
-    } else if (auto parallelOp = user->getParentOfType<omp::ParallelOp>()) {
-      if (parentOp->isProperAncestor(parallelOp))
+    } else if (auto callOp = dyn_cast<CallOpInterface>(user)) {
+      if (llvm::is_contained(callOp.getArgOperands(), value))
         return true;
     }
+
+    if (auto parallelOp = user->getParentOfType<omp::ParallelOp>()) {
+      if (parentOp->isProperAncestor(parallelOp)) {
+        // If it is used directly inside of a parallel region, skip private
+        // clause uses.
+        bool isPrivateClauseUse = false;
+        if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(user)) {
+          if (auto privateSyms = llvm::cast_or_null<ArrayAttr>(
+                  user->getAttr("private_syms"))) {
+            for (auto [var, sym] :
+                 llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) {
+              if (var != value)
+                continue;
+
+              auto privateOp = cast<omp::PrivateClauseOp>(
+                  moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym)));
+              if (privateOp.getCopyRegion().empty()) {
+                isPrivateClauseUse = true;
+                break;
+              }
+            }
+          }
+        }
+        if (!isPrivateClauseUse)
+          return true;
+      }
+    }
   }
 
   return false;
@@ -1206,8 +1230,8 @@ allocReductionVars(T op, ArrayRef<BlockArgument> reductionArgs,
   builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-  bool useDeviceSharedMem =
-      isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory(*op, *ompBuilder);
+  bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
+                            mightAllocInDeviceSharedMemory(*op, *ompBuilder);
 
   // delay creating stores until after all allocas
   deferredStores.reserve(op.getNumReductionVars());
@@ -1338,8 +1362,8 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
     return success();
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-  bool useDeviceSharedMem =
-      isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory(*op, *ompBuilder);
+  bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
+                            mightAllocInDeviceSharedMemory(*op, *ompBuilder);
 
   llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
   auto allocaIP = llvm::IRBuilderBase::InsertPoint(
@@ -1586,8 +1610,8 @@ static LogicalResult createReductionsAndCleanup(
       reductionRegions, privateReductionVariables, moduleTranslation, builder,
       "omp.reduction.cleanup");
 
-  bool useDeviceSharedMem =
-      isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory(*op, *ompBuilder);
+  bool useDeviceSharedMem = isa<omp::TeamsOp>(*op) &&
+                            mightAllocInDeviceSharedMemory(*op, *ompBuilder);
   if (useDeviceSharedMem) {
     for (auto [var, reductionDecl] :
          llvm::zip_equal(privateReductionVariables, reductionDecls))
@@ -1779,7 +1803,7 @@ allocatePrivateVars(T op, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   bool mightUseDeviceSharedMem =
-      isa<omp::TeamsOp, omp::DistributeOp>(*op) &&
+      isa<omp::TargetOp, omp::TeamsOp, omp::DistributeOp>(*op) &&
       mightAllocInDeviceSharedMemory(*op, *ompBuilder);
   unsigned int allocaAS =
       moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
@@ -1933,7 +1957,7 @@ cleanupPrivateVars(T op, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   bool mightUseDeviceSharedMem =
-      isa<omp::TeamsOp, omp::DistributeOp>(*op) &&
+      isa<omp::TargetOp, omp::TeamsOp, omp::DistributeOp>(*op) &&
       mightAllocInDeviceSharedMemory(*op, *ompBuilder);
   for (auto [privDecl, llvmPrivVar, blockArg] :
        llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.llvmVars,
@@ -6754,14 +6778,14 @@ handleDeclareTargetMapVar(MapInfoData &mapData,
 // a store of the kernel argument into this allocated memory which
 // will then be loaded from, ByCopy will use the allocated memory
 // directly.
-static llvm::IRBuilderBase::InsertPoint
-createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
-                             llvm::Value *input, llvm::Value *&retVal,
-                             llvm::IRBuilderBase &builder,
-                             llvm::OpenMPIRBuilder &ompBuilder,
-                             LLVM::ModuleTranslation &moduleTranslation,
-                             llvm::IRBuilderBase::InsertPoint allocaIP,
-                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
+static llvm::IRBuilderBase::InsertPoint createDeviceArgumentAccessor(
+    omp::TargetOp targetOp, MapInfoData &mapData, llvm::Argument &arg,
+    llvm::Value *input, llvm::Value *&retVal, llvm::IRBuilderBase &builder,
+    llvm::OpenMPIRBuilder &ompBuilder,
+    LLVM::ModuleTranslation &moduleTranslation,
+    llvm::IRBuilderBase::InsertPoint allocaIP,
+    llvm::IRBuilderBase::InsertPoint codeGenIP,
+    llvm::ArrayRef<llvm::IRBuilderBase::InsertPoint> deallocIPs) {
   assert(ompBuilder.Config.isTargetDevice() &&
          "function only supported for target device codegen");
   builder.restoreIP(allocaIP);
@@ -6770,26 +6794,62 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
   LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
       ompBuilder.M.getContext());
   unsigned alignmentValue = 0;
+  BlockArgument mlirArg;
+  SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
+  cast<omp::BlockArgOpenMPOpInterface>(*targetOp).getBlockArgsPairs(
+      blockArgsPairs);
   // Find the associated MapInfoData entry for the current input
-  for (size_t i = 0; i < mapData.MapClause.size(); ++i)
+  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
     if (mapData.OriginalValue[i] == input) {
       auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
       capture = mapOp.getMapCaptureType();
       // Get information of alignment of mapped object
       alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
           mapOp.getVarType(), ompBuilder.M.getDataLayout());
+
+      // Find the corresponding entry block argument, which can be associated to
+      // a map, use_device* or has_device* clause.
+      for (auto &[val, arg] : blockArgsPairs) {
+        if (mapOp.getResult() == val) {
+          mlirArg = arg;
+          break;
+        }
+      }
+      assert(mlirArg && "expected to find entry block argument for map clause");
       break;
     }
+  }
 
   unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
   unsigned int defaultAS =
       ompBuilder.M.getDataLayout().getProgramAddressSpace();
 
-  // Create the alloca for the argument the current point.
-  llvm::Value *v = builder.CreateAlloca(arg.getType(), allocaAS);
+  // Create the allocation for the argument.
+  llvm::Value *v = nullptr;
+  if (mightAllocInDeviceSharedMemory(*targetOp, ompBuilder) &&
+      mustAllocPrivateVarInDeviceSharedMemory(mlirArg)) {
+    // Use the beginning of the codeGenIP rather than the usual allocation point
+    // for shared memory allocations because otherwise these would be done prior
+    // to the target initialization call. Also, the exit block (where the
+    // deallocation is placed) is only executed if the initialization call
+    // succeeds.
+    builder.SetInsertPoint(codeGenIP.getBlock()->getFirstInsertionPt());
+    v = ompBuilder.createOMPAllocShared(builder, arg.getType());
+
+    // Create deallocations in all provided deallocation points and then restore
+    // the insertion point to right after the new allocations.
+    llvm::IRBuilderBase::InsertPointGuard guard(builder);
+    for (auto deallocIP : deallocIPs) {
+      builder.SetInsertPoint(deallocIP.getBlock(), deallocIP.getPoint());
+      ompBuilder.createOMPFreeShared(builder, v, arg.getType());
+    }
+  } else {
+    // Use the current point, which was previously set to allocaIP.
+    v = builder.CreateAlloca(arg.getType(), allocaAS);
 
-  if (allocaAS != defaultAS && arg.getType()->isPointerTy())
-    v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
+    if (allocaAS != defaultAS && arg.getType()->isPointerTy())
+      v = builder.CreateAddrSpaceCast(v, builder.getPtrTy(defaultAS));
+  }
 
   builder.CreateStore(&arg, v);
 
@@ -7419,7 +7479,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
   auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
                            llvm::Value *&retVal, InsertPointTy allocaIP,
-                           InsertPointTy codeGenIP)
+                           InsertPointTy codeGenIP,
+                           llvm::ArrayRef<InsertPointTy> deallocIPs)
       -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
     llvm::IRBuilderBase::InsertPointGuard guard(builder);
     builder.SetCurrentDebugLocation(llvm::DebugLoc());
@@ -7433,9 +7494,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
       return codeGenIP;
     }
 
-    return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
-                                        *ompBuilder, moduleTranslation,
-                                        allocaIP, codeGenIP);
+    return createDeviceArgumentAccessor(targetOp, mapData, arg, input, retVal,
+                                        builder, *ompBuilder, moduleTranslation,
+                                        allocaIP, codeGenIP, deallocIPs);
   };
 
   llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
diff --git a/mlir/test/Target/LLVMIR/omptarget-constant-alloca-raise.mlir b/mlir/test/Target/LLVMIR/omptarget-constant-alloca-raise.mlir
index 724e03885d146..3543a23f46d7d 100644
--- a/mlir/test/Target/LLVMIR/omptarget-constant-alloca-raise.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-constant-alloca-raise.mlir
@@ -39,6 +39,6 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK-NEXT: entry:
 // CHECK-NEXT:  %[[MOVED_ALLOCA1:.*]] = alloca { ptr }, align 8
 // CHECK-NEXT:  %[[MOVED_ALLOCA2:.*]] = alloca i32, i64 1, align 4
-// CHECK-NEXT:  %[[MAP_ARG_ALLOCA:.*]] = alloca ptr, align 8
 
 // CHECK: user_code.entry:                                  ; preds = %entry
+// CHECK-NEXT:  %[[MAP_ARG_ALLOC:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index c24f5cf796468..ba745b6871e3d 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -55,15 +55,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK: define weak_odr protected amdgpu_kernel void @[[FUNC0:.*]](
 // CHECK-SAME: ptr %[[TMP0:.*]], ptr %[[TMP:.*]]) #{{[0-9]+}} {
 // CHECK:         %[[TMP1:.*]] = alloca [1 x ptr], align 8, addrspace(5)
-// CHECK:         %[[TMP2:.*]] = alloca ptr, align 8, addrspace(5)
-// CHECK:         %[[TMP3:.*]] = addrspacecast ptr addrspace(5) %[[TMP2]] to ptr
-// CHECK:         store ptr %[[TMP0]], ptr %[[TMP3]], align 8
 // CHECK:         %[[TMP4:.*]] = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) @{{.*}} to ptr), ptr %[[TMP]])
 // CHECK:         %[[EXEC_USER_CODE:.*]] = icmp eq i32 %[[TMP4]], -1
 // CHECK:         br i1 %[[EXEC_USER_CODE]], label %[[USER_CODE_ENTRY:.*]], label %[[WORKER_EXIT:.*]]
 // CHECK:         %[[TMP5:.*]] = addrspacecast ptr addrspace(5) %[[TMP1]] to ptr
 // CHECK:         %[[STRUCTARG:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
-// CHECK:         %[[TMP6:.*]] = load ptr, ptr %[[TMP3]], align 8
+// CHECK:         %[[TMP2:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
+// CHECK:         store ptr %[[TMP0]], ptr %[[TMP2]], align 8
+// CHECK:         %[[TMP6:.*]] = load ptr, ptr %[[TMP2]], align 8
 // CHECK:         %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr))
 // CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
 // CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
@@ -71,6 +70,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
 // CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP5]], i64 1, i32 0)
 // CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
+// CHECK:         call void @__kmpc_free_shared(ptr %[[TMP2]], i64 8)
 // CHECK:         call void @__kmpc_target_deinit()
 
 // CHECK: define internal void @[[FUNC1]](
diff --git a/mlir/test/Target/LLVMIR/openmp-target-private-shared-mem.mlir b/mlir/test/Target/LLVMIR/openmp-target-private-shared-mem.mlir
new file mode 100644
index 0000000000000..1481d8133cb0c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-private-shared-mem.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = true, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true,  dlti.dl_spec = #dlti.dl_spec<!llvm.ptr = dense<64> : vector<4xi64>, !llvm.ptr<1> = dense<64> : vector<4xi64>, !llvm.ptr<2> = dense<32> : vector<4xi64>, !llvm.ptr<3> = dense<32> : vector<4xi64>, !llvm.ptr<4> = dense<64> : vector<4xi64>, !llvm.ptr<5> = dense<32> : vector<4xi64>, !llvm.ptr<6> = dense<32> : vector<4xi64>, !llvm.ptr<7> = dense<[160, 256, 256, 32]> : vector<4xi64>, !llvm.ptr<8> = dense<[128, 128, 128, 48]> : vector<4xi64>, !llvm.ptr<9> = dense<[192, 256, 256, 32]> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, "dlti.endianness" = "little", "dlti.legal_int_widths" = array<i32: 32, 64>, "dlti.stack_alignment" = 32 : i64, "dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>} {
+  omp.private {type = private} @simple_var.privatizer : i32
+  omp.declare_reduction @simple_var.reducer : i32 init {
+  ^bb0(%arg0: i32):
+    %0 = llvm.mlir.constant(0 : i32) : i32
+    omp.yield(%0 : i32)
+  } combiner {
+  ^bb0(%arg0: i32, %arg1: i32):
+    %0 = llvm.add %arg0, %arg1 : i32
+    omp.yield(%0 : i32)
+  }
+
+  // CHECK-LABEL: declare void @device_func(ptr)
+  llvm.func @device_func(!llvm.ptr) attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>}
+  // The host-only function itself must not be emitted for the device.
+  // CHECK-NOT: define {{.*}} void @target_map_single_shared_mem_private
+  llvm.func @target_map_single_shared_mem_private() attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>} {
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    %1 = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr<5>
+    %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+    // Case 1: private var passed as a call argument -> device shared memory.
+    // CHECK-LABEL: define {{.*}} void @__omp_offloading_{{.*}}target_map_single_shared_mem_private{{.*}}({{.*}})
+    // CHECK: call i32 @__kmpc_target_init
+    // CHECK: %[[ALLOC0:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 4)
+    // CHECK: call void @device_func(ptr %[[ALLOC0]])
+    // CHECK: call void @__kmpc_free_shared(ptr %[[ALLOC0]], i64 4)
+    // CHECK: call void @__kmpc_target_deinit
+    omp.target private(@simple_var.privatizer %2 -> %arg0 : !llvm.ptr) {
+      llvm.call @device_func(%arg0) : (!llvm.ptr) -> ()
+      omp.terminator
+    }
+    // Case 2: reduction var of a nested omp.parallel -> device shared memory.
+    // CHECK-LABEL: define {{.*}} void @__omp_offloading_{{.*}}target_map_single_shared_mem_private{{.*}}({{.*}})
+    // CHECK: call i32 @__kmpc_target_init
+    // CHECK: %[[ALLOC_ARGS0:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
+    // CHECK: %[[ALLOC1:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 4)
+    // CHECK: %[[GEP0:.*]] = getelementptr { ptr }, ptr %[[ALLOC_ARGS0]], i32 0, i32 0
+    // CHECK: store ptr %[[ALLOC1]], ptr %[[GEP0]], align 8
+    // CHECK: %[[GEP1:.*]] = getelementptr inbounds [1 x ptr], ptr %[[PAR_ARGS0:.*]], i64 0, i64 0
+    // CHECK: store ptr %[[ALLOC_ARGS0]], ptr %[[GEP1]], align 8
+    // CHECK: call void @__kmpc_parallel_60(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr %[[PAR_ARGS0]], i64 1, i32 0)
+    // CHECK: call void @__kmpc_free_shared(ptr %[[ALLOC_ARGS0]], i64 8)
+    // CHECK: call void @__kmpc_free_shared(ptr %[[ALLOC1]], i64 4)
+    // CHECK: call void @__kmpc_target_deinit
+    omp.target private(@simple_var.privatizer %2 -> %arg0 : !llvm.ptr) {
+      omp.parallel reduction(@simple_var.reducer %arg0 -> %arg1 : !llvm.ptr) {
+        %3 = llvm.load %arg1 : !llvm.ptr -> i32
+        omp.terminator
+      }
+      omp.terminator
+    }
+    // Case 3: used directly inside a nested omp.parallel -> device shared memory.
+    // CHECK-LABEL: define {{.*}} void @__omp_offloading_{{.*}}target_map_single_shared_mem_private{{.*}}({{.*}})
+    // CHECK: call i32 @__kmpc_target_init
+    // CHECK: %[[ALLOC_ARGS1:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
+    // CHECK: %[[ALLOC2:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 4)
+    // CHECK: %[[GEP2:.*]] = getelementptr { ptr }, ptr %[[ALLOC_ARGS1]], i32 0, i32 0
+    // CHECK: store ptr %[[ALLOC2]], ptr %[[GEP2]], align 8
+    // CHECK: %[[GEP3:.*]] = getelementptr inbounds [1 x ptr], ptr %[[PAR_ARGS1:.*]], i64 0, i64 0
+    // CHECK: store ptr %[[ALLOC_ARGS1]], ptr %[[GEP3]], align 8
+    // CHECK: call void @__kmpc_parallel_60(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr %[[PAR_ARGS1]], i64 1, i32 0)
+    // CHECK: call void @__kmpc_free_shared(ptr %[[ALLOC_ARGS1]], i64 8)
+    // CHECK: call void @__kmpc_free_shared(ptr %[[ALLOC2]], i64 4)
+    // CHECK: call void @__kmpc_target_deinit
+    omp.target private(@simple_var.privatizer %2 -> %arg0 : !llvm.ptr) {
+      omp.parallel {
+        %4 = llvm.load %arg0 : !llvm.ptr -> i32
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
diff --git a/offload/test/offloading/fortran/target-generic-outlined-loops.f90 b/offload/test/offloading/fortran/target-generic-outlined-loops.f90
new file mode 100644
index 0000000000000..594809027e115
--- /dev/null
+++ b/offload/test/offloading/fortran/target-generic-outlined-loops.f90
@@ -0,0 +1,109 @@
+! Offloading test for generic target regions containing different kinds of
+! loop constructs inside, moving parallel regions into a separate subroutine.
+! REQUIRES: flang, amdgpu
+
+! RUN: %libomptarget-compile-fortran-run-and-check-generic
+subroutine parallel_loop(n, counter)
+  implicit none
+  integer, intent(in) :: n
+  integer, intent(inout) :: counter
+  integer :: iter
+  ! Increments counter once per iteration of a parallel reduction loop over 1..n.
+  !$omp parallel do reduction(+:counter)
+  do iter = 1, n
+    counter = counter + 1
+  end do
+end subroutine parallel_loop
+
+program main
+  implicit none
+  integer :: i1, n1, n2, counter
+
+  n1 = 100
+  n2 = 50
+
+  counter = 0
+  !$omp target map(tofrom:counter)
+    !$omp teams distribute reduction(+:counter)
+    do i1=1, n1
+      counter = counter + 1
+    end do
+  !$omp end target
+
+  ! CHECK: 1 100
+  print '(I2," ",I0)', 1, counter
+
+  counter = 0
+  !$omp target map(tofrom:counter)
+    call parallel_loop(n1, counter)
+    call parallel_loop(n1, counter)
+  !$omp end target
+
+  ! CHECK: 2 200
+  print '(I2," ",I0)', 2, counter
+
+  counter = 0
+  !$omp target map(tofrom:counter)
+    counter = counter + 1
+    call parallel_loop(n1, counter)
+    counter = counter + 1
+    call parallel_loop(n1, counter)
+    counter = counter + 1
+  !$omp end target
+
+  ! CHECK: 3 203
+  print '(I2," ",I0)', 3, counter
+
+  counter = 0
+  !$omp target map(tofrom: counter)
+    counter = counter + 1
+    call parallel_loop(n1, counter)
+    counter = counter + 1
+  !$omp end target
+
+  ! CHECK: 4 102
+  print '(I2," ",I0)', 4, counter
+
+  counter = 0
+  !$omp target teams distribute reduction(+:counter)
+  do i1=1, n1
+    call parallel_loop(n2, counter)
+  end do
+
+  ! CHECK: 5 5000
+  print '(I2," ",I0)', 5, counter
+
+  counter = 0
+  !$omp target teams distribute reduction(+:counter)
+  do i1=1, n1
+    counter = counter + 1
+    call parallel_loop(n2, counter)
+    counter = counter + 1
+  end do
+
+  ! CHECK: 6 5200
+  print '(I2," ",I0)', 6, counter
+
+  counter = 0
+  !$omp target teams distribute reduction(+:counter)
+  do i1=1, n1
+    call parallel_loop(n2, counter)
+    call parallel_loop(n2, counter)
+  end do
+
+  ! CHECK: 7 10000
+  print '(I2," ",I0)', 7, counter
+
+  counter = 0
+  !$omp target teams distribute reduction(+:counter)
+  do i1=1, n1
+    counter = counter + 1
+    call parallel_loop(n2, counter)
+    counter = counter + 1
+    call parallel_loop(n2, counter)
+    counter = counter + 1
+  end do
+
+  ! CHECK: 8 10300
+  print '(I2," ",I0)', 8, counter
+end program main



More information about the llvm-branch-commits mailing list