[clang] 227012c - [OpenMP] Migrate device code privatisation from Clang CodeGen to OMPIRBuilder

Akash Banerjee via cfe-commits cfe-commits at lists.llvm.org
Wed Jul 12 04:03:38 PDT 2023


Author: Akash Banerjee
Date: 2023-07-12T12:03:28+01:00
New Revision: 227012cbd71f7f55c9f28f561a069628a964a97a

URL: https://github.com/llvm/llvm-project/commit/227012cbd71f7f55c9f28f561a069628a964a97a
DIFF: https://github.com/llvm/llvm-project/commit/227012cbd71f7f55c9f28f561a069628a964a97a.diff

LOG: [OpenMP] Migrate device code privatisation from Clang CodeGen to OMPIRBuilder

This patch migrates the privatisation handling code for the UseDevicePtr and UseDeviceAddr clauses from Clang CodeGen to the OMPIRBuilder.

Depends on D150860

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D152554

Added: 
    

Modified: 
    clang/lib/CodeGen/CGOpenMPRuntime.cpp
    clang/lib/CodeGen/CGOpenMPRuntime.h
    clang/lib/CodeGen/CGStmtOpenMP.cpp
    clang/lib/CodeGen/CodeGenFunction.h
    clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index e5952bf436df93..a52ec8909b12ac 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6819,6 +6819,7 @@ class MappableExprsHandler {
     const Expr *getMapExpr() const { return MapExpr; }
   };
 
+  using DeviceInfoTy = llvm::OpenMPIRBuilder::DeviceInfoTy;
   using MapBaseValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy;
   using MapValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy;
   using MapFlagsArrayTy = llvm::OpenMPIRBuilder::MapFlagsArrayTy;
@@ -7589,6 +7590,7 @@ class MappableExprsHandler {
             CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
             CombinedInfo.BasePointers.push_back(BP.getPointer());
             CombinedInfo.DevicePtrDecls.push_back(nullptr);
+            CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
             CombinedInfo.Pointers.push_back(LB.getPointer());
             CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
                 Size, CGF.Int64Ty, /*isSigned=*/true));
@@ -7601,6 +7603,7 @@ class MappableExprsHandler {
           CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
           CombinedInfo.BasePointers.push_back(BP.getPointer());
           CombinedInfo.DevicePtrDecls.push_back(nullptr);
+          CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
           CombinedInfo.Pointers.push_back(LB.getPointer());
           Size = CGF.Builder.CreatePtrDiff(
               CGF.Int8Ty, CGF.Builder.CreateConstGEP(HB, 1).getPointer(),
@@ -7619,6 +7622,7 @@ class MappableExprsHandler {
           CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
           CombinedInfo.BasePointers.push_back(BP.getPointer());
           CombinedInfo.DevicePtrDecls.push_back(nullptr);
+          CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
           CombinedInfo.Pointers.push_back(LB.getPointer());
           CombinedInfo.Sizes.push_back(
               CGF.Builder.CreateIntCast(Size, CGF.Int64Ty, /*isSigned=*/true));
@@ -8119,10 +8123,12 @@ class MappableExprsHandler {
 
     auto &&UseDeviceDataCombinedInfoGen =
         [&UseDeviceDataCombinedInfo](const ValueDecl *VD, llvm::Value *Ptr,
-                                     CodeGenFunction &CGF) {
+                                     CodeGenFunction &CGF, bool IsDevAddr) {
           UseDeviceDataCombinedInfo.Exprs.push_back(VD);
           UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr);
           UseDeviceDataCombinedInfo.DevicePtrDecls.emplace_back(VD);
+          UseDeviceDataCombinedInfo.DevicePointers.emplace_back(
+              IsDevAddr ? DeviceInfoTy::Address : DeviceInfoTy::Pointer);
           UseDeviceDataCombinedInfo.Pointers.push_back(Ptr);
           UseDeviceDataCombinedInfo.Sizes.push_back(
               llvm::Constant::getNullValue(CGF.Int64Ty));
@@ -8162,7 +8168,7 @@ class MappableExprsHandler {
             } else {
               Ptr = CGF.EmitLoadOfScalar(CGF.EmitLValue(IE), IE->getExprLoc());
             }
-            UseDeviceDataCombinedInfoGen(VD, Ptr, CGF);
+            UseDeviceDataCombinedInfoGen(VD, Ptr, CGF, IsDevAddr);
           }
         };
 
@@ -8189,6 +8195,7 @@ class MappableExprsHandler {
           // item.
           if (CI != Data.end()) {
             if (IsDevAddr) {
+              CI->ForDeviceAddr = IsDevAddr;
               CI->ReturnDevicePointer = true;
               Found = true;
               break;
@@ -8201,6 +8208,7 @@ class MappableExprsHandler {
                   PrevCI == CI->Components.rend() ||
                   isa<MemberExpr>(PrevCI->getAssociatedExpression()) || !VarD ||
                   VarD->hasLocalStorage()) {
+                CI->ForDeviceAddr = IsDevAddr;
                 CI->ReturnDevicePointer = true;
                 Found = true;
                 break;
@@ -8292,6 +8300,8 @@ class MappableExprsHandler {
                    "No relevant declaration related with device pointer??");
 
             CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD;
+            CurInfo.DevicePointers[CurrentBasePointersIdx] =
+                L.ForDeviceAddr ? DeviceInfoTy::Address : DeviceInfoTy::Pointer;
             CurInfo.Types[CurrentBasePointersIdx] |=
                 OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
           }
@@ -8332,6 +8342,8 @@ class MappableExprsHandler {
           CurInfo.Exprs.push_back(L.VD);
           CurInfo.BasePointers.emplace_back(BasePtr);
           CurInfo.DevicePtrDecls.emplace_back(L.VD);
+          CurInfo.DevicePointers.emplace_back(
+              L.ForDeviceAddr ? DeviceInfoTy::Address : DeviceInfoTy::Pointer);
           CurInfo.Pointers.push_back(Ptr);
           CurInfo.Sizes.push_back(
               llvm::Constant::getNullValue(this->CGF.Int64Ty));
@@ -8427,6 +8439,7 @@ class MappableExprsHandler {
     // Base is the base of the struct
     CombinedInfo.BasePointers.push_back(PartialStruct.Base.getPointer());
     CombinedInfo.DevicePtrDecls.push_back(nullptr);
+    CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
     // Pointer is the address of the lowest element
     llvm::Value *LB = LBAddr.getPointer();
     const CXXMethodDecl *MD =
@@ -8549,6 +8562,7 @@ class MappableExprsHandler {
       CombinedInfo.Exprs.push_back(VD);
       CombinedInfo.BasePointers.push_back(ThisLVal.getPointer(CGF));
       CombinedInfo.DevicePtrDecls.push_back(nullptr);
+      CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
       CombinedInfo.Pointers.push_back(ThisLValVal.getPointer(CGF));
       CombinedInfo.Sizes.push_back(
           CGF.Builder.CreateIntCast(CGF.getTypeSize(CGF.getContext().VoidPtrTy),
@@ -8576,6 +8590,7 @@ class MappableExprsHandler {
         CombinedInfo.Exprs.push_back(VD);
         CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF));
         CombinedInfo.DevicePtrDecls.push_back(nullptr);
+        CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
         CombinedInfo.Pointers.push_back(VarLValVal.getPointer(CGF));
         CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
             CGF.getTypeSize(
@@ -8588,6 +8603,7 @@ class MappableExprsHandler {
         CombinedInfo.Exprs.push_back(VD);
         CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF));
         CombinedInfo.DevicePtrDecls.push_back(nullptr);
+        CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
         CombinedInfo.Pointers.push_back(VarRVal.getScalarVal());
         CombinedInfo.Sizes.push_back(llvm::ConstantInt::get(CGF.Int64Ty, 0));
       }
@@ -8656,6 +8672,7 @@ class MappableExprsHandler {
       CombinedInfo.Exprs.push_back(VD);
       CombinedInfo.BasePointers.emplace_back(Arg);
       CombinedInfo.DevicePtrDecls.emplace_back(VD);
+      CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer);
       CombinedInfo.Pointers.push_back(Arg);
       CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
           CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty,
@@ -8896,6 +8913,7 @@ class MappableExprsHandler {
       CombinedInfo.Exprs.push_back(nullptr);
       CombinedInfo.BasePointers.push_back(CV);
       CombinedInfo.DevicePtrDecls.push_back(nullptr);
+      CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
       CombinedInfo.Pointers.push_back(CV);
       const auto *PtrTy = cast<PointerType>(RI.getType().getTypePtr());
       CombinedInfo.Sizes.push_back(
@@ -8909,6 +8927,7 @@ class MappableExprsHandler {
       CombinedInfo.Exprs.push_back(VD->getCanonicalDecl());
       CombinedInfo.BasePointers.push_back(CV);
       CombinedInfo.DevicePtrDecls.push_back(nullptr);
+      CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
       CombinedInfo.Pointers.push_back(CV);
       if (!RI.getType()->isAnyPointerType()) {
         // We have to signal to the runtime captures passed by value that are
@@ -8941,6 +8960,7 @@ class MappableExprsHandler {
       CombinedInfo.Exprs.push_back(VD->getCanonicalDecl());
       CombinedInfo.BasePointers.push_back(CV);
       CombinedInfo.DevicePtrDecls.push_back(nullptr);
+      CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
       if (I != FirstPrivateDecls.end() && ElementType->isAnyPointerType()) {
         Address PtrAddr = CGF.EmitLoadOfReference(CGF.MakeAddrLValue(
             CV, ElementType, CGF.getContext().getDeclAlign(VD),
@@ -9022,7 +9042,6 @@ static void emitOffloadingArrays(
     CGOpenMPRuntime::TargetDataInfo &Info, llvm::OpenMPIRBuilder &OMPBuilder,
     bool IsNonContiguous = false) {
   CodeGenModule &CGM = CGF.CGM;
-  ASTContext &Ctx = CGF.getContext();
 
   // Reset the array information.
   Info.clearArrayInfo();
@@ -9044,11 +9063,9 @@ static void emitOffloadingArrays(
                     FillInfoMap);
   }
 
-  auto DeviceAddrCB = [&](unsigned int I, llvm::Value *BP, llvm::Value *BPVal) {
+  auto DeviceAddrCB = [&](unsigned int I, llvm::Value *NewDecl) {
     if (const ValueDecl *DevVD = CombinedInfo.DevicePtrDecls[I]) {
-      Address BPAddr(BP, BPVal->getType(),
-                     Ctx.getTypeAlignInChars(Ctx.VoidPtrTy));
-      Info.CaptureDeviceAddrMap.try_emplace(DevVD, BPAddr);
+      Info.CaptureDeviceAddrMap.try_emplace(DevVD, NewDecl);
     }
   };
 
@@ -9663,6 +9680,8 @@ static void emitTargetCallKernelLaunch(
       CurInfo.Exprs.push_back(nullptr);
       CurInfo.BasePointers.push_back(*CV);
       CurInfo.DevicePtrDecls.push_back(nullptr);
+      CurInfo.DevicePointers.push_back(
+          MappableExprsHandler::DeviceInfoTy::None);
       CurInfo.Pointers.push_back(*CV);
       CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
           CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true));
@@ -10463,12 +10482,9 @@ void CGOpenMPRuntime::emitTargetDataCalls(
                          CGF.Builder.GetInsertPoint());
   };
 
-  auto DeviceAddrCB = [&](unsigned int I, llvm::Value *BP, llvm::Value *BPVal) {
+  auto DeviceAddrCB = [&](unsigned int I, llvm::Value *NewDecl) {
     if (const ValueDecl *DevVD = CombinedInfo.DevicePtrDecls[I]) {
-      ASTContext &Ctx = CGF.getContext();
-      Address BPAddr(BP, BPVal->getType(),
-                     Ctx.getTypeAlignInChars(Ctx.VoidPtrTy));
-      Info.CaptureDeviceAddrMap.try_emplace(DevVD, BPAddr);
+      Info.CaptureDeviceAddrMap.try_emplace(DevVD, NewDecl);
     }
   };
 

diff  --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index ebdbf4c8c77bf0..2ee2a39ba538c4 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -1458,9 +1458,9 @@ class CGOpenMPRuntime {
                             bool SeparateBeginEndCalls)
         : llvm::OpenMPIRBuilder::TargetDataInfo(RequiresDevicePointerInfo,
                                                 SeparateBeginEndCalls) {}
-    /// Map between the a declaration of a capture and the corresponding base
-    /// pointer address where the runtime returns the device pointers.
-    llvm::DenseMap<const ValueDecl *, Address> CaptureDeviceAddrMap;
+    /// Map between the a declaration of a capture and the corresponding new
+    /// llvm address where the runtime returns the device pointers.
+    llvm::DenseMap<const ValueDecl *, llvm::Value *> CaptureDeviceAddrMap;
   };
 
   /// Emit the target data mapping code associated with \a D.

diff  --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 74a04996a3d025..0c87b70e82cfbf 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -7174,14 +7174,13 @@ CodeGenFunction::getOMPCancelDestination(OpenMPDirectiveKind Kind) {
 
 void CodeGenFunction::EmitOMPUseDevicePtrClause(
     const OMPUseDevicePtrClause &C, OMPPrivateScope &PrivateScope,
-    const llvm::DenseMap<const ValueDecl *, Address> &CaptureDeviceAddrMap) {
-  auto OrigVarIt = C.varlist_begin();
-  auto InitIt = C.inits().begin();
-  for (const Expr *PvtVarIt : C.private_copies()) {
-    const auto *OrigVD =
-        cast<VarDecl>(cast<DeclRefExpr>(*OrigVarIt)->getDecl());
-    const auto *InitVD = cast<VarDecl>(cast<DeclRefExpr>(*InitIt)->getDecl());
-    const auto *PvtVD = cast<VarDecl>(cast<DeclRefExpr>(PvtVarIt)->getDecl());
+    const llvm::DenseMap<const ValueDecl *, llvm::Value *>
+        CaptureDeviceAddrMap) {
+  llvm::SmallDenseSet<CanonicalDeclPtr<const Decl>, 4> Processed;
+  for (const Expr *OrigVarIt : C.varlists()) {
+    const auto *OrigVD = cast<VarDecl>(cast<DeclRefExpr>(OrigVarIt)->getDecl());
+    if (!Processed.insert(OrigVD).second)
+      continue;
 
     // In order to identify the right initializer we need to match the
     // declaration used by the mapping logic. In some cases we may get
@@ -7202,32 +7201,16 @@ void CodeGenFunction::EmitOMPUseDevicePtrClause(
     if (InitAddrIt == CaptureDeviceAddrMap.end())
       continue;
 
-    // Initialize the temporary initialization variable with the address
-    // we get from the runtime library. We have to cast the source address
-    // because it is always a void *. References are materialized in the
-    // privatization scope, so the initialization here disregards the fact
-    // the original variable is a reference.
     llvm::Type *Ty = ConvertTypeForMem(OrigVD->getType().getNonReferenceType());
-    Address InitAddr = InitAddrIt->second.withElementType(Ty);
-    setAddrOfLocalVar(InitVD, InitAddr);
-
-    // Emit private declaration, it will be initialized by the value we
-    // declaration we just added to the local declarations map.
-    EmitDecl(*PvtVD);
-
-    // The initialization variables reached its purpose in the emission
-    // of the previous declaration, so we don't need it anymore.
-    LocalDeclMap.erase(InitVD);
 
     // Return the address of the private variable.
-    bool IsRegistered =
-        PrivateScope.addPrivate(OrigVD, GetAddrOfLocalVar(PvtVD));
+    bool IsRegistered = PrivateScope.addPrivate(
+        OrigVD,
+        Address(InitAddrIt->second, Ty,
+                getContext().getTypeAlignInChars(getContext().VoidPtrTy)));
     assert(IsRegistered && "firstprivate var already registered as private");
     // Silence the warning about unused variable.
     (void)IsRegistered;
-
-    ++OrigVarIt;
-    ++InitIt;
   }
 }
 
@@ -7242,7 +7225,8 @@ static const VarDecl *getBaseDecl(const Expr *Ref) {
 
 void CodeGenFunction::EmitOMPUseDeviceAddrClause(
     const OMPUseDeviceAddrClause &C, OMPPrivateScope &PrivateScope,
-    const llvm::DenseMap<const ValueDecl *, Address> &CaptureDeviceAddrMap) {
+    const llvm::DenseMap<const ValueDecl *, llvm::Value *>
+        CaptureDeviceAddrMap) {
   llvm::SmallDenseSet<CanonicalDeclPtr<const Decl>, 4> Processed;
   for (const Expr *Ref : C.varlists()) {
     const VarDecl *OrigVD = getBaseDecl(Ref);
@@ -7267,7 +7251,11 @@ void CodeGenFunction::EmitOMPUseDeviceAddrClause(
     if (InitAddrIt == CaptureDeviceAddrMap.end())
       continue;
 
-    Address PrivAddr = InitAddrIt->getSecond();
+    llvm::Type *Ty = ConvertTypeForMem(OrigVD->getType().getNonReferenceType());
+
+    Address PrivAddr =
+        Address(InitAddrIt->second, Ty,
+                getContext().getTypeAlignInChars(getContext().VoidPtrTy));
     // For declrefs and variable length array need to load the pointer for
     // correct mapping, since the pointer to the data was passed to the runtime.
     if (isa<DeclRefExpr>(Ref->IgnoreParenImpCasts()) ||

diff  --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 1debbbf57fe6b9..bcd4ef4520740b 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3403,10 +3403,12 @@ class CodeGenFunction : public CodeGenTypeCache {
                             OMPPrivateScope &PrivateScope);
   void EmitOMPUseDevicePtrClause(
       const OMPUseDevicePtrClause &C, OMPPrivateScope &PrivateScope,
-      const llvm::DenseMap<const ValueDecl *, Address> &CaptureDeviceAddrMap);
+      const llvm::DenseMap<const ValueDecl *, llvm::Value *>
+          CaptureDeviceAddrMap);
   void EmitOMPUseDeviceAddrClause(
       const OMPUseDeviceAddrClause &C, OMPPrivateScope &PrivateScope,
-      const llvm::DenseMap<const ValueDecl *, Address> &CaptureDeviceAddrMap);
+      const llvm::DenseMap<const ValueDecl *, llvm::Value *>
+          CaptureDeviceAddrMap);
   /// Emit code for copyin clause in \a D directive. The next code is
   /// generated at the start of outlined functions for directives:
   /// \code

diff  --git a/clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp b/clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp
index 0e9dbd39fd6418..5570c815716101 100644
--- a/clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp
+++ b/clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp
@@ -411,11 +411,11 @@ struct ST {
     // CK2:     [[BP2:%.+]] = getelementptr inbounds [3 x ptr], ptr %{{.+}}, i32 0, i32 2
     // CK2:     store ptr [[RVAL2:%.+]], ptr [[BP2]],
     // CK2:     call void @__tgt_target_data_begin{{.+}}[[MTYPE03]]
+    // CK2:     [[VAL1:%.+]] = load ptr, ptr [[BP1]],
+    // CK2:     store ptr [[VAL1]], ptr [[PVT1:%.+]],
     // CK2:     [[VAL2:%.+]] = load ptr, ptr [[BP2]],
     // CK2:     store ptr [[VAL2]], ptr [[PVT2:%.+]],
     // CK2:     store ptr [[PVT2]], ptr [[_PVT2:%.+]],
-    // CK2:     [[VAL1:%.+]] = load ptr, ptr [[BP1]],
-    // CK2:     store ptr [[VAL1]], ptr [[PVT1:%.+]],
     // CK2:     store ptr [[PVT1]], ptr [[_PVT1:%.+]],
     // CK2:     [[TT2:%.+]] = load ptr, ptr [[_PVT2]],
     // CK2:     [[_TT2:%.+]] = load ptr, ptr [[TT2]],

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3cf9ac84bf0bfe..cca04ba74c3160 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1610,6 +1610,9 @@ class OpenMPIRBuilder {
   public:
     TargetDataRTArgs RTArgs;
 
+    SmallMapVector<const Value *, std::pair<Value *, Value *>, 4>
+        DevicePtrInfoMap;
+
     /// Indicate whether any user-defined mapper exists.
     bool HasMapper = false;
     /// The total number of pointers passed to the runtime library.
@@ -1636,7 +1639,9 @@ class OpenMPIRBuilder {
     bool separateBeginEndCalls() { return SeparateBeginEndCalls; }
   };
 
+  enum class DeviceInfoTy { None, Pointer, Address };
   using MapValuesArrayTy = SmallVector<Value *, 4>;
+  using MapDeviceInfoArrayTy = SmallVector<DeviceInfoTy, 4>;
   using MapFlagsArrayTy = SmallVector<omp::OpenMPOffloadMappingFlags, 4>;
   using MapNamesArrayTy = SmallVector<Constant *, 4>;
   using MapDimArrayTy = SmallVector<uint64_t, 4>;
@@ -1655,6 +1660,7 @@ class OpenMPIRBuilder {
     };
     MapValuesArrayTy BasePointers;
     MapValuesArrayTy Pointers;
+    MapDeviceInfoArrayTy DevicePointers;
     MapValuesArrayTy Sizes;
     MapFlagsArrayTy Types;
     MapNamesArrayTy Names;
@@ -1665,6 +1671,8 @@ class OpenMPIRBuilder {
       BasePointers.append(CurInfo.BasePointers.begin(),
                           CurInfo.BasePointers.end());
       Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end());
+      DevicePointers.append(CurInfo.DevicePointers.begin(),
+                            CurInfo.DevicePointers.end());
       Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end());
       Types.append(CurInfo.Types.begin(), CurInfo.Types.end());
       Names.append(CurInfo.Names.begin(), CurInfo.Names.end());
@@ -1723,7 +1731,7 @@ class OpenMPIRBuilder {
   void emitOffloadingArrays(
       InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
       TargetDataInfo &Info, bool IsNonContiguous = false,
-      function_ref<void(unsigned int, Value *, Value *)> DeviceAddrCB = nullptr,
+      function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
       function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
 
   /// Creates offloading entry for the provided entry ID \a ID, address \a
@@ -2110,7 +2118,7 @@ class OpenMPIRBuilder {
       function_ref<InsertPointTy(InsertPointTy CodeGenIP,
                                  BodyGenTy BodyGenType)>
           BodyGenCB = nullptr,
-      function_ref<void(unsigned int, Value *, Value *)> DeviceAddrCB = nullptr,
+      function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
       function_ref<Value *(unsigned int)> CustomMapperCB = nullptr,
       Value *SrcLocInfo = nullptr);
 

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 494cd1a0f2f016..1b2837226aca57 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4166,7 +4166,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
     omp::RuntimeFunction *MapperFunc,
     function_ref<InsertPointTy(InsertPointTy CodeGenIP, BodyGenTy BodyGenType)>
         BodyGenCB,
-    function_ref<void(unsigned int, Value *, Value *)> DeviceAddrCB,
+    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
     function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -4213,6 +4213,14 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
 
       Builder.CreateCall(BeginMapperFunc, OffloadingArgs);
 
+      for (auto DeviceMap : Info.DevicePtrInfoMap) {
+        if (isa<AllocaInst>(DeviceMap.second.second)) {
+          auto *LI =
+              Builder.CreateLoad(Builder.getPtrTy(), DeviceMap.second.first);
+          Builder.CreateStore(LI, DeviceMap.second.second);
+        }
+      }
+
       // If device pointer privatization is required, emit the body of the
       // region here. It will have to be duplicated: with and without
       // privatization.
@@ -4628,7 +4636,7 @@ void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
 void OpenMPIRBuilder::emitOffloadingArrays(
     InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
     TargetDataInfo &Info, bool IsNonContiguous,
-    function_ref<void(unsigned int, Value *, Value *)> DeviceAddrCB,
+    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
     function_ref<Value *(unsigned int)> CustomMapperCB) {
 
   // Reset the array information.
@@ -4766,9 +4774,21 @@ void OpenMPIRBuilder::emitOffloadingArrays(
         BPVal, BP, M.getDataLayout().getPrefTypeAlign(Builder.getInt8PtrTy()));
 
     if (Info.requiresDevicePointerInfo()) {
-      assert(DeviceAddrCB &&
-             "DeviceAddrCB missing for DevicePtr code generation");
-      DeviceAddrCB(I, BP, BPVal);
+      if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
+        CodeGenIP = Builder.saveIP();
+        Builder.restoreIP(AllocaIP);
+        Info.DevicePtrInfoMap[BPVal] = {
+            BP, Builder.CreateAlloca(Builder.getPtrTy())};
+        Builder.restoreIP(CodeGenIP);
+        assert(DeviceAddrCB &&
+               "DeviceAddrCB missing for DevicePtr code generation");
+        DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
+      } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
+        Info.DevicePtrInfoMap[BPVal] = {BP, BP};
+        assert(DeviceAddrCB &&
+               "DeviceAddrCB missing for DevicePtr code generation");
+        DeviceAddrCB(I, BP);
+      }
     }
 
     Value *PVal = CombinedInfo.Pointers[I];


        


More information about the cfe-commits mailing list