[clang] 35309db - [OpenMP][OMPIRBuilder] Migrate MapCombinedInfoTy from Clang to OpenMPIRBuilder

Akash Banerjee via cfe-commits cfe-commits at lists.llvm.org
Thu May 4 08:51:17 PDT 2023


Author: Akash Banerjee
Date: 2023-05-04T16:51:06+01:00
New Revision: 35309db7dcefde20180e924df303f65a14d97d68

URL: https://github.com/llvm/llvm-project/commit/35309db7dcefde20180e924df303f65a14d97d68
DIFF: https://github.com/llvm/llvm-project/commit/35309db7dcefde20180e924df303f65a14d97d68.diff

LOG: [OpenMP][OMPIRBuilder] Migrate MapCombinedInfoTy from Clang to OpenMPIRBuilder

This patch migrates the MapCombinedInfoTy from Clang codegen to OpenMPIRBuilder.

Differential Revision: https://reviews.llvm.org/D149666

Added: 
    

Modified: 
    clang/lib/CodeGen/CGOpenMPRuntime.cpp
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 1400afbb3a9e0..ce92aa21e0fc5 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6831,67 +6831,30 @@ class MappableExprsHandler {
     const Expr *getMapExpr() const { return MapExpr; }
   };
 
-  /// Class that associates information with a base pointer to be passed to the
-  /// runtime library.
-  class BasePointerInfo {
-    /// The base pointer.
-    llvm::Value *Ptr = nullptr;
-    /// The base declaration that refers to this device pointer, or null if
-    /// there is none.
-    const ValueDecl *DevPtrDecl = nullptr;
-
-  public:
-    BasePointerInfo(llvm::Value *Ptr, const ValueDecl *DevPtrDecl = nullptr)
-        : Ptr(Ptr), DevPtrDecl(DevPtrDecl) {}
-    llvm::Value *operator*() const { return Ptr; }
-    const ValueDecl *getDevicePtrDecl() const { return DevPtrDecl; }
-    void setDevicePtrDecl(const ValueDecl *D) { DevPtrDecl = D; }
-  };
-
+  using MapBaseValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy;
+  using MapValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy;
+  using MapFlagsArrayTy = llvm::OpenMPIRBuilder::MapFlagsArrayTy;
+  using MapDimArrayTy = llvm::OpenMPIRBuilder::MapDimArrayTy;
+  using MapNonContiguousArrayTy =
+      llvm::OpenMPIRBuilder::MapNonContiguousArrayTy;
   using MapExprsArrayTy = SmallVector<MappingExprInfo, 4>;
-  using MapBaseValuesArrayTy = SmallVector<BasePointerInfo, 4>;
-  using MapValuesArrayTy = SmallVector<llvm::Value *, 4>;
-  using MapFlagsArrayTy = SmallVector<OpenMPOffloadMappingFlags, 4>;
-  using MapMappersArrayTy = SmallVector<const ValueDecl *, 4>;
-  using MapDimArrayTy = SmallVector<uint64_t, 4>;
-  using MapNonContiguousArrayTy = SmallVector<MapValuesArrayTy, 4>;
+  using MapValueDeclsArrayTy = SmallVector<const ValueDecl *, 4>;
 
   /// This structure contains combined information generated for mappable
   /// clauses, including base pointers, pointers, sizes, map types, user-defined
   /// mappers, and non-contiguous information.
-  struct MapCombinedInfoTy {
-    struct StructNonContiguousInfo {
-      bool IsNonContiguous = false;
-      MapDimArrayTy Dims;
-      MapNonContiguousArrayTy Offsets;
-      MapNonContiguousArrayTy Counts;
-      MapNonContiguousArrayTy Strides;
-    };
+  struct MapCombinedInfoTy : llvm::OpenMPIRBuilder::MapInfosTy {
     MapExprsArrayTy Exprs;
-    MapBaseValuesArrayTy BasePointers;
-    MapValuesArrayTy Pointers;
-    MapValuesArrayTy Sizes;
-    MapFlagsArrayTy Types;
-    MapMappersArrayTy Mappers;
-    StructNonContiguousInfo NonContigInfo;
+    MapValueDeclsArrayTy Mappers;
+    MapValueDeclsArrayTy DevicePtrDecls;
 
     /// Append arrays in \a CurInfo.
     void append(MapCombinedInfoTy &CurInfo) {
       Exprs.append(CurInfo.Exprs.begin(), CurInfo.Exprs.end());
-      BasePointers.append(CurInfo.BasePointers.begin(),
-                          CurInfo.BasePointers.end());
-      Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end());
-      Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end());
-      Types.append(CurInfo.Types.begin(), CurInfo.Types.end());
+      DevicePtrDecls.append(CurInfo.DevicePtrDecls.begin(),
+                            CurInfo.DevicePtrDecls.end());
       Mappers.append(CurInfo.Mappers.begin(), CurInfo.Mappers.end());
-      NonContigInfo.Dims.append(CurInfo.NonContigInfo.Dims.begin(),
-                                 CurInfo.NonContigInfo.Dims.end());
-      NonContigInfo.Offsets.append(CurInfo.NonContigInfo.Offsets.begin(),
-                                    CurInfo.NonContigInfo.Offsets.end());
-      NonContigInfo.Counts.append(CurInfo.NonContigInfo.Counts.begin(),
-                                   CurInfo.NonContigInfo.Counts.end());
-      NonContigInfo.Strides.append(CurInfo.NonContigInfo.Strides.begin(),
-                                    CurInfo.NonContigInfo.Strides.end());
+      llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
     }
   };
 
@@ -7638,6 +7601,7 @@ class MappableExprsHandler {
             assert(Size && "Failed to determine structure size");
             CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
             CombinedInfo.BasePointers.push_back(BP.getPointer());
+            CombinedInfo.DevicePtrDecls.push_back(nullptr);
             CombinedInfo.Pointers.push_back(LB.getPointer());
             CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
                 Size, CGF.Int64Ty, /*isSigned=*/true));
@@ -7649,6 +7613,7 @@ class MappableExprsHandler {
           }
           CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
           CombinedInfo.BasePointers.push_back(BP.getPointer());
+          CombinedInfo.DevicePtrDecls.push_back(nullptr);
           CombinedInfo.Pointers.push_back(LB.getPointer());
           Size = CGF.Builder.CreatePtrDiff(
               CGF.Int8Ty, CGF.Builder.CreateConstGEP(HB, 1).getPointer(),
@@ -7666,6 +7631,7 @@ class MappableExprsHandler {
             (Next == CE && MapType != OMPC_MAP_unknown)) {
           CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
           CombinedInfo.BasePointers.push_back(BP.getPointer());
+          CombinedInfo.DevicePtrDecls.push_back(nullptr);
           CombinedInfo.Pointers.push_back(LB.getPointer());
           CombinedInfo.Sizes.push_back(
               CGF.Builder.CreateIntCast(Size, CGF.Int64Ty, /*isSigned=*/true));
@@ -8168,7 +8134,8 @@ class MappableExprsHandler {
         [&UseDeviceDataCombinedInfo](const ValueDecl *VD, llvm::Value *Ptr,
                                      CodeGenFunction &CGF) {
           UseDeviceDataCombinedInfo.Exprs.push_back(VD);
-          UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr, VD);
+          UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr);
+          UseDeviceDataCombinedInfo.DevicePtrDecls.emplace_back(VD);
           UseDeviceDataCombinedInfo.Pointers.push_back(Ptr);
           UseDeviceDataCombinedInfo.Sizes.push_back(
               llvm::Constant::getNullValue(CGF.Int64Ty));
@@ -8337,8 +8304,7 @@ class MappableExprsHandler {
             assert(RelevantVD &&
                    "No relevant declaration related with device pointer??");
 
-            CurInfo.BasePointers[CurrentBasePointersIdx].setDevicePtrDecl(
-                RelevantVD);
+            CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD;
             CurInfo.Types[CurrentBasePointersIdx] |=
                 OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
           }
@@ -8377,7 +8343,8 @@ class MappableExprsHandler {
                 OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
           }
           CurInfo.Exprs.push_back(L.VD);
-          CurInfo.BasePointers.emplace_back(BasePtr, L.VD);
+          CurInfo.BasePointers.emplace_back(BasePtr);
+          CurInfo.DevicePtrDecls.emplace_back(L.VD);
           CurInfo.Pointers.push_back(Ptr);
           CurInfo.Sizes.push_back(
               llvm::Constant::getNullValue(this->CGF.Int64Ty));
@@ -8472,6 +8439,7 @@ class MappableExprsHandler {
     CombinedInfo.Exprs.push_back(VD);
     // Base is the base of the struct
     CombinedInfo.BasePointers.push_back(PartialStruct.Base.getPointer());
+    CombinedInfo.DevicePtrDecls.push_back(nullptr);
     // Pointer is the address of the lowest element
     llvm::Value *LB = LBAddr.getPointer();
     const CXXMethodDecl *MD =
@@ -8593,6 +8561,7 @@ class MappableExprsHandler {
                                  VDLVal.getPointer(CGF));
       CombinedInfo.Exprs.push_back(VD);
       CombinedInfo.BasePointers.push_back(ThisLVal.getPointer(CGF));
+      CombinedInfo.DevicePtrDecls.push_back(nullptr);
       CombinedInfo.Pointers.push_back(ThisLValVal.getPointer(CGF));
       CombinedInfo.Sizes.push_back(
           CGF.Builder.CreateIntCast(CGF.getTypeSize(CGF.getContext().VoidPtrTy),
@@ -8619,6 +8588,7 @@ class MappableExprsHandler {
                                    VDLVal.getPointer(CGF));
         CombinedInfo.Exprs.push_back(VD);
         CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF));
+        CombinedInfo.DevicePtrDecls.push_back(nullptr);
         CombinedInfo.Pointers.push_back(VarLValVal.getPointer(CGF));
         CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
             CGF.getTypeSize(
@@ -8630,6 +8600,7 @@ class MappableExprsHandler {
                                    VDLVal.getPointer(CGF));
         CombinedInfo.Exprs.push_back(VD);
         CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF));
+        CombinedInfo.DevicePtrDecls.push_back(nullptr);
         CombinedInfo.Pointers.push_back(VarRVal.getScalarVal());
         CombinedInfo.Sizes.push_back(llvm::ConstantInt::get(CGF.Int64Ty, 0));
       }
@@ -8654,7 +8625,7 @@ class MappableExprsHandler {
                        OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF |
                        OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
         continue;
-      llvm::Value *BasePtr = LambdaPointers.lookup(*BasePointers[I]);
+      llvm::Value *BasePtr = LambdaPointers.lookup(BasePointers[I]);
       assert(BasePtr && "Unable to find base lambda address.");
       int TgtIdx = -1;
       for (unsigned J = I; J > 0; --J) {
@@ -8696,7 +8667,8 @@ class MappableExprsHandler {
     // pass its value.
     if (VD && (DevPointersMap.count(VD) || HasDevAddrsMap.count(VD))) {
       CombinedInfo.Exprs.push_back(VD);
-      CombinedInfo.BasePointers.emplace_back(Arg, VD);
+      CombinedInfo.BasePointers.emplace_back(Arg);
+      CombinedInfo.DevicePtrDecls.emplace_back(VD);
       CombinedInfo.Pointers.push_back(Arg);
       CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
           CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty,
@@ -8938,6 +8910,7 @@ class MappableExprsHandler {
     if (CI.capturesThis()) {
       CombinedInfo.Exprs.push_back(nullptr);
       CombinedInfo.BasePointers.push_back(CV);
+      CombinedInfo.DevicePtrDecls.push_back(nullptr);
       CombinedInfo.Pointers.push_back(CV);
       const auto *PtrTy = cast<PointerType>(RI.getType().getTypePtr());
       CombinedInfo.Sizes.push_back(
@@ -8950,6 +8923,7 @@ class MappableExprsHandler {
       const VarDecl *VD = CI.getCapturedVar();
       CombinedInfo.Exprs.push_back(VD->getCanonicalDecl());
       CombinedInfo.BasePointers.push_back(CV);
+      CombinedInfo.DevicePtrDecls.push_back(nullptr);
       CombinedInfo.Pointers.push_back(CV);
       if (!RI.getType()->isAnyPointerType()) {
         // We have to signal to the runtime captures passed by value that are
@@ -8981,6 +8955,7 @@ class MappableExprsHandler {
       auto I = FirstPrivateDecls.find(VD);
       CombinedInfo.Exprs.push_back(VD->getCanonicalDecl());
       CombinedInfo.BasePointers.push_back(CV);
+      CombinedInfo.DevicePtrDecls.push_back(nullptr);
       if (I != FirstPrivateDecls.end() && ElementType->isAnyPointerType()) {
         Address PtrAddr = CGF.EmitLoadOfReference(CGF.MakeAddrLValue(
             CV, ElementType, CGF.getContext().getDeclAlign(VD),
@@ -9266,7 +9241,7 @@ static void emitOffloadingArrays(
     }
 
     for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
-      llvm::Value *BPVal = *CombinedInfo.BasePointers[I];
+      llvm::Value *BPVal = CombinedInfo.BasePointers[I];
       llvm::Value *BP = CGF.Builder.CreateConstInBoundsGEP2_32(
           llvm::ArrayType::get(CGM.VoidPtrTy, Info.NumberOfPtrs),
           Info.RTArgs.BasePointersArray, 0, I);
@@ -9277,8 +9252,7 @@ static void emitOffloadingArrays(
       CGF.Builder.CreateStore(BPVal, BPAddr);
 
       if (Info.requiresDevicePointerInfo())
-        if (const ValueDecl *DevVD =
-                CombinedInfo.BasePointers[I].getDevicePtrDecl())
+        if (const ValueDecl *DevVD = CombinedInfo.DevicePtrDecls[I])
           Info.CaptureDeviceAddrMap.try_emplace(DevVD, BPAddr);
 
       llvm::Value *PVal = CombinedInfo.Pointers[I];
@@ -9592,7 +9566,7 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
   // Fill up the runtime mapper handle for all components.
   for (unsigned I = 0; I < Info.BasePointers.size(); ++I) {
     llvm::Value *CurBaseArg = MapperCGF.Builder.CreateBitCast(
-        *Info.BasePointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy));
+        Info.BasePointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy));
     llvm::Value *CurBeginArg = MapperCGF.Builder.CreateBitCast(
         Info.Pointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy));
     llvm::Value *CurSizeArg = Info.Sizes[I];
@@ -10028,6 +10002,7 @@ void CGOpenMPRuntime::emitTargetCall(
       if (CI->capturesVariableArrayType()) {
         CurInfo.Exprs.push_back(nullptr);
         CurInfo.BasePointers.push_back(*CV);
+        CurInfo.DevicePtrDecls.push_back(nullptr);
         CurInfo.Pointers.push_back(*CV);
         CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
             CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true));

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 43292f63bf756..d5d1646fb5a37 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1445,6 +1445,49 @@ class OpenMPIRBuilder {
     bool separateBeginEndCalls() { return SeparateBeginEndCalls; }
   };
 
+  using MapValuesArrayTy = SmallVector<Value *, 4>;
+  using MapFlagsArrayTy = SmallVector<omp::OpenMPOffloadMappingFlags, 4>;
+  using MapNamesArrayTy = SmallVector<Constant *, 4>;
+  using MapDimArrayTy = SmallVector<uint64_t, 4>;
+  using MapNonContiguousArrayTy = SmallVector<MapValuesArrayTy, 4>;
+
+  /// This structure contains combined information generated for mappable
+  /// clauses, including base pointers, pointers, sizes, map types, map names,
+  /// and non-contiguous information.
+  struct MapInfosTy {
+    struct StructNonContiguousInfo {
+      bool IsNonContiguous = false;
+      MapDimArrayTy Dims;
+      MapNonContiguousArrayTy Offsets;
+      MapNonContiguousArrayTy Counts;
+      MapNonContiguousArrayTy Strides;
+    };
+    MapValuesArrayTy BasePointers;
+    MapValuesArrayTy Pointers;
+    MapValuesArrayTy Sizes;
+    MapFlagsArrayTy Types;
+    MapNamesArrayTy Names;
+    StructNonContiguousInfo NonContigInfo;
+
+    /// Append arrays in \a CurInfo.
+    void append(MapInfosTy &CurInfo) {
+      BasePointers.append(CurInfo.BasePointers.begin(),
+                          CurInfo.BasePointers.end());
+      Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end());
+      Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end());
+      Types.append(CurInfo.Types.begin(), CurInfo.Types.end());
+      Names.append(CurInfo.Names.begin(), CurInfo.Names.end());
+      NonContigInfo.Dims.append(CurInfo.NonContigInfo.Dims.begin(),
+                                CurInfo.NonContigInfo.Dims.end());
+      NonContigInfo.Offsets.append(CurInfo.NonContigInfo.Offsets.begin(),
+                                   CurInfo.NonContigInfo.Offsets.end());
+      NonContigInfo.Counts.append(CurInfo.NonContigInfo.Counts.begin(),
+                                  CurInfo.NonContigInfo.Counts.end());
+      NonContigInfo.Strides.append(CurInfo.NonContigInfo.Strides.begin(),
+                                   CurInfo.NonContigInfo.Strides.end());
+    }
+  };
+
   /// Emit the arguments to be passed to the runtime library based on the
   /// arrays of base pointers, pointers, sizes, map types, and mappers.  If
   /// ForEndCall, emit map types to be passed for the end of the region instead


        


More information about the cfe-commits mailing list