[llvm-branch-commits] [llvm] [mlir] [OpenMP]Update use_device_clause lowering (PR #101707)

Akash Banerjee via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Aug 19 08:04:05 PDT 2024


https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/101707

>From 3a2afe783bfd65c981424fb14d2b0f42ea0b6618 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 2 Aug 2024 17:11:21 +0100
Subject: [PATCH] [OpenMP]Update use_device_clause lowering

This patch updates the use_device_ptr and use_device_addr clauses to use the mapInfoOps for lowering. This allows all the types that are handled by the map clauses, such as derived types, to also be supported by the use_device clauses.

This is patch 2/2 in a series of patches.
---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |   2 +-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 284 ++++++++++--------
 mlir/test/Target/LLVMIR/omptarget-llvm.mlir   |  16 +-
 .../openmp-target-use-device-nested.mlir      |  27 ++
 4 files changed, 194 insertions(+), 135 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 83fec194d73904..f5d94069ad6f4c 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6357,7 +6357,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
   // Disable TargetData CodeGen on Device pass.
   if (Config.IsTargetDevice.value_or(false)) {
     if (BodyGenCB)
-      Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
+      Builder.restoreIP(BodyGenCB(CodeGenIP, BodyGenTy::NoPriv));
     return Builder.saveIP();
   }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 458d05d5059db7..78c460c50cbe5e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2110,6 +2110,8 @@ getRefPtrIfDeclareTarget(mlir::Value value,
 struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
   llvm::SmallVector<bool, 4> IsDeclareTarget;
   llvm::SmallVector<bool, 4> IsAMember;
+  // True if added by a map clause; false if added only by use_device_*.
+  llvm::SmallVector<bool, 4> IsAMapping;
   llvm::SmallVector<mlir::Operation *, 4> MapClause;
   llvm::SmallVector<llvm::Value *, 4> OriginalValue;
   // Stripped off array/pointer to get the underlying
@@ -2193,62 +2195,125 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
   return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
 }
 
-void collectMapDataFromMapVars(MapInfoData &mapData,
-                               llvm::SmallVectorImpl<Value> &mapVars,
-                               LLVM::ModuleTranslation &moduleTranslation,
-                               DataLayout &dl, llvm::IRBuilderBase &builder) {
+void collectMapDataFromMapOperands(
+    MapInfoData &mapData, llvm::SmallVectorImpl<Value> &mapVars,
+    LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
+    llvm::IRBuilderBase &builder,
+    llvm::ArrayRef<Value> useDevPtrOperands = {},
+    llvm::ArrayRef<Value> useDevAddrOperands = {}) {
+  // Process MapOperands
   for (mlir::Value mapValue : mapVars) {
-    if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
-            mapValue.getDefiningOp())) {
-      mlir::Value offloadPtr =
-          mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
-      mapData.OriginalValue.push_back(
-          moduleTranslation.lookupValue(offloadPtr));
-      mapData.Pointers.push_back(mapData.OriginalValue.back());
-
-      if (llvm::Value *refPtr =
-              getRefPtrIfDeclareTarget(offloadPtr,
-                                       moduleTranslation)) { // declare target
-        mapData.IsDeclareTarget.push_back(true);
-        mapData.BasePointers.push_back(refPtr);
-      } else { // regular mapped variable
-        mapData.IsDeclareTarget.push_back(false);
-        mapData.BasePointers.push_back(mapData.OriginalValue.back());
-      }
+    auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp());
+    mlir::Value offloadPtr =
+        mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
+    mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
+    mapData.Pointers.push_back(mapData.OriginalValue.back());
+
+    if (llvm::Value *refPtr =
+            getRefPtrIfDeclareTarget(offloadPtr,
+                                     moduleTranslation)) { // declare target
+      mapData.IsDeclareTarget.push_back(true);
+      mapData.BasePointers.push_back(refPtr);
+    } else { // regular mapped variable
+      mapData.IsDeclareTarget.push_back(false);
+      mapData.BasePointers.push_back(mapData.OriginalValue.back());
+    }
 
-      mapData.BaseType.push_back(
-          moduleTranslation.convertType(mapOp.getVarType()));
-      mapData.Sizes.push_back(
-          getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
-                         mapData.BaseType.back(), builder, moduleTranslation));
-      mapData.MapClause.push_back(mapOp.getOperation());
-      mapData.Types.push_back(
-          llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
-      mapData.Names.push_back(LLVM::createMappingInformation(
-          mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
-      mapData.DevicePointers.push_back(
-          llvm::OpenMPIRBuilder::DeviceInfoTy::None);
-
-      // Check if this is a member mapping and correctly assign that it is, if
-      // it is a member of a larger object.
-      // TODO: Need better handling of members, and distinguishing of members
-      // that are implicitly allocated on device vs explicitly passed in as
-      // arguments.
-      // TODO: May require some further additions to support nested record
-      // types, i.e. member maps that can have member maps.
-      mapData.IsAMember.push_back(false);
-      for (mlir::Value mapValue : mapVars) {
-        if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
-                mapValue.getDefiningOp())) {
-          for (auto member : map.getMembers()) {
-            if (member == mapOp) {
-              mapData.IsAMember.back() = true;
-            }
+    mapData.BaseType.push_back(
+        moduleTranslation.convertType(mapOp.getVarType()));
+    mapData.Sizes.push_back(
+        getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
+                       mapData.BaseType.back(), builder, moduleTranslation));
+    mapData.MapClause.push_back(mapOp.getOperation());
+    mapData.Types.push_back(
+        llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
+    mapData.Names.push_back(LLVM::createMappingInformation(
+        mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
+    mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+    mapData.IsAMapping.push_back(true);
+
+    // Check if this is a member mapping and correctly assign that it is, if
+    // it is a member of a larger object.
+    // TODO: Need better handling of members, and distinguishing of members
+    // that are implicitly allocated on device vs explicitly passed in as
+    // arguments.
+    // TODO: May require some further additions to support nested record
+    // types, i.e. member maps that can have member maps.
+    mapData.IsAMember.push_back(false);
+    for (mlir::Value parentValue : mapVars) {
+      if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
+              parentValue.getDefiningOp())) {
+        for (auto member : map.getMembers()) {
+          if (member == mapOp) {
+            mapData.IsAMember.back() = true;
           }
         }
       }
     }
   }
+
+  auto findMapInfo = [&mapData](llvm::Value *val,
+                                llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
+    unsigned index = 0;
+    bool found = false;
+    for (llvm::Value *basePtr : mapData.OriginalValue) {
+      if (basePtr == val && mapData.IsAMapping[index]) {
+        found = true;
+        mapData.Types[index] |=
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+        mapData.DevicePointers[index] = devInfoTy;
+      }
+      index++;
+    }
+    return found;
+  };
+
+  // Process useDevPtr(Addr)Operands
+  auto addDevInfos = [&](llvm::ArrayRef<Value> useDevOperands,
+                         llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
+    for (mlir::Value mapValue : useDevOperands) {
+      auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp());
+      mlir::Value offloadPtr =
+          mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
+      llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
+
+      // Check if map info is already present for this entry.
+      if (!findMapInfo(origValue, devInfoTy)) {
+        mapData.OriginalValue.push_back(origValue);
+        mapData.Pointers.push_back(mapData.OriginalValue.back());
+        mapData.IsDeclareTarget.push_back(false);
+        mapData.BasePointers.push_back(mapData.OriginalValue.back());
+        mapData.BaseType.push_back(
+            moduleTranslation.convertType(mapOp.getVarType()));
+        mapData.Sizes.push_back(builder.getInt64(0));
+        mapData.MapClause.push_back(mapOp.getOperation());
+        mapData.Types.push_back(
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
+        mapData.Names.push_back(LLVM::createMappingInformation(
+            mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
+        mapData.DevicePointers.push_back(devInfoTy);
+        mapData.IsAMapping.push_back(false);
+
+        // Check if this is a member mapping and correctly assign that it is,
+        // if it is a member of a larger object.
+        // TODO: Need better handling of members, and distinguishing of
+        // members that are implicitly allocated on device vs explicitly
+        // passed in as arguments.
+        // TODO: May require some further additions to support nested record
+        // types, i.e. member maps that can have member maps.
+        mapData.IsAMember.push_back(false);
+        for (mlir::Value parentValue : useDevOperands)
+          if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
+                  parentValue.getDefiningOp()))
+            for (auto member : map.getMembers())
+              if (member == mapOp)
+                mapData.IsAMember.back() = true;
+      }
+    }
+  };
+
+  addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
+  addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
 }
 
 static int getMapDataMemberIdx(MapInfoData &mapData,
@@ -2426,7 +2491,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
           ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
           : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
   combinedInfo.DevicePointers.emplace_back(
-      llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+      mapData.DevicePointers[mapDataIndex]);
   combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
       mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
   combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
@@ -2553,7 +2618,7 @@ static void processMapMembersWithParent(
 
     combinedInfo.Types.emplace_back(mapFlag);
     combinedInfo.DevicePointers.emplace_back(
-        llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+        mapData.DevicePointers[memberDataIdx]);
     combinedInfo.Names.emplace_back(
         LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
     combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
@@ -2714,10 +2779,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
                         LLVM::ModuleTranslation &moduleTranslation,
                         DataLayout &dl,
                         llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
-                        MapInfoData &mapData,
-                        const SmallVector<Value> &useDevicePtrVars = {},
-                        const SmallVector<Value> &useDeviceAddrVars = {},
-                        bool isTargetParams = false) {
+                        MapInfoData &mapData, bool isTargetParams = false) {
   // We wish to modify some of the methods in which arguments are
   // passed based on their capture type by the target region, this can
   // involve generating new loads and stores, which changes the
@@ -2734,15 +2796,6 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
 
-  auto fail = [&combinedInfo]() -> void {
-    combinedInfo.BasePointers.clear();
-    combinedInfo.Pointers.clear();
-    combinedInfo.DevicePointers.clear();
-    combinedInfo.Sizes.clear();
-    combinedInfo.Types.clear();
-    combinedInfo.Names.clear();
-  };
-
   // We operate under the assumption that all vectors that are
   // required in MapInfoData are of equal lengths (either filled with
   // default constructed data or appropiate information) so we can
@@ -2763,46 +2816,6 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
 
     processIndividualMap(mapData, i, combinedInfo, isTargetParams);
   }
-
-  auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {
-    index = 0;
-    for (llvm::Value *basePtr : combinedInfo.BasePointers) {
-      if (basePtr == val)
-        return true;
-      index++;
-    }
-    return false;
-  };
-
-  auto addDevInfos = [&, fail](auto useDeviceVars, auto devOpType) -> void {
-    for (const auto &useDeviceVar : useDeviceVars) {
-      // TODO: Only LLVMPointerTypes are handled.
-      if (!isa<LLVM::LLVMPointerType>(useDeviceVar.getType()))
-        return fail();
-
-      llvm::Value *mapOpValue = moduleTranslation.lookupValue(useDeviceVar);
-
-      // Check if map info is already present for this entry.
-      unsigned infoIndex;
-      if (findMapInfo(mapOpValue, infoIndex)) {
-        combinedInfo.Types[infoIndex] |=
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
-        combinedInfo.DevicePointers[infoIndex] = devOpType;
-      } else {
-        combinedInfo.BasePointers.emplace_back(mapOpValue);
-        combinedInfo.Pointers.emplace_back(mapOpValue);
-        combinedInfo.DevicePointers.emplace_back(devOpType);
-        combinedInfo.Names.emplace_back(
-            LLVM::createMappingInformation(useDeviceVar.getLoc(), *ompBuilder));
-        combinedInfo.Types.emplace_back(
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
-        combinedInfo.Sizes.emplace_back(builder.getInt64(0));
-      }
-    }
-  };
-
-  addDevInfos(useDevicePtrVars, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
-  addDevInfos(useDeviceAddrVars, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
 }
 
 static LogicalResult
@@ -2899,19 +2912,32 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
 
   MapInfoData mapData;
-  collectMapDataFromMapVars(mapData, mapVars, moduleTranslation, DL, builder);
+  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
+                                builder, useDevicePtrVars, useDeviceAddrVars);
+
+  // Returns the base pointer of the mapData entry that was created or tagged
+  // for a use_device_ptr/use_device_addr operand with the given device info
+  // kind, or nullptr if there is no such entry.
+  auto findUseDevBasePtr =
+      [&](mlir::Value useDevVar,
+          llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) -> llvm::Value * {
+    auto useDevMap = mlir::cast<mlir::omp::MapInfoOp>(useDevVar.getDefiningOp());
+    mlir::Value varPtr = useDevMap.getVarPtrPtr() ? useDevMap.getVarPtrPtr()
+                                                  : useDevMap.getVarPtr();
+    llvm::Value *origValue = moduleTranslation.lookupValue(varPtr);
+    for (size_t i = 0; i < mapData.OriginalValue.size(); ++i)
+      if (mapData.OriginalValue[i] == origValue &&
+          mapData.DevicePointers[i] == devInfoTy)
+        return mapData.BasePointers[i];
+    return nullptr;
+  };
 
   // Fill up the arrays with all the mapped variables.
   llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
   auto genMapInfoCB =
       [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
     builder.restoreIP(codeGenIP);
-    if (auto dataOp = dyn_cast<omp::TargetDataOp>(op)) {
-      genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
-                  useDevicePtrVars, useDeviceAddrVars);
-    } else {
-      genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
-    }
+    genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
     return combinedInfo;
   };
 
@@ -2930,21 +2956,31 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
       if (!info.DevicePtrInfoMap.empty()) {
         builder.restoreIP(codeGenIP);
         unsigned argIndex = 0;
-        for (auto &devPtrOp : useDevicePtrVars) {
-          llvm::Value *mapOpValue = moduleTranslation.lookupValue(devPtrOp);
-          const auto &arg = region.front().getArgument(argIndex);
-          moduleTranslation.mapValue(arg,
-                                     info.DevicePtrInfoMap[mapOpValue].second);
-          argIndex++;
-        }
-
-        for (auto &devAddrOp : useDeviceAddrVars) {
-          llvm::Value *mapOpValue = moduleTranslation.lookupValue(devAddrOp);
-          const auto &arg = region.front().getArgument(argIndex);
-          auto *LI = builder.CreateLoad(
-              builder.getPtrTy(), info.DevicePtrInfoMap[mapOpValue].second);
-          moduleTranslation.mapValue(arg, LI);
-          argIndex++;
+        // Block arguments follow clause order: one per use_device_ptr
+        // operand, then one per use_device_addr operand. Bind each one via
+        // the mapData entry of its own operand, since the order of entries
+        // in combinedInfo need not match the clause order.
+        for (mlir::Value useDevVar : useDevicePtrVars) {
+          llvm::Value *basePtr = findUseDevBasePtr(
+              useDevVar, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
+          if (llvm::Value *devPtr =
+                  basePtr ? info.DevicePtrInfoMap.lookup(basePtr).second
+                          : nullptr)
+            moduleTranslation.mapValue(region.front().getArgument(argIndex),
+                                       devPtr);
+          argIndex++;
+        }
+
+        for (mlir::Value useDevVar : useDeviceAddrVars) {
+          llvm::Value *basePtr = findUseDevBasePtr(
+              useDevVar, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
+          if (llvm::Value *devPtr =
+                  basePtr ? info.DevicePtrInfoMap.lookup(basePtr).second
+                          : nullptr)
+            moduleTranslation.mapValue(
+                region.front().getArgument(argIndex),
+                builder.CreateLoad(builder.getPtrTy(), devPtr));
+          argIndex++;
         }
 
         bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
@@ -2957,6 +2993,26 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
       // If device info is available then region has already been generated
       if (info.DevicePtrInfoMap.empty()) {
         builder.restoreIP(codeGenIP);
+        // For device pass, if use_device_ptr(addr) mappings were present,
+        // we need to link them here before codegen. Block arguments follow
+        // clause order: use_device_ptr operands, then use_device_addr ones.
+        if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
+          unsigned argIndex = 0;
+          for (mlir::Value useDevVar : useDevicePtrVars) {
+            if (llvm::Value *basePtr = findUseDevBasePtr(
+                    useDevVar, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer))
+              moduleTranslation.mapValue(region.front().getArgument(argIndex),
+                                         basePtr);
+            argIndex++;
+          }
+          for (mlir::Value useDevVar : useDeviceAddrVars) {
+            if (llvm::Value *basePtr = findUseDevBasePtr(
+                    useDevVar, llvm::OpenMPIRBuilder::DeviceInfoTy::Address))
+              moduleTranslation.mapValue(region.front().getArgument(argIndex),
+                                         basePtr);
+            argIndex++;
+          }
+        }
         bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
                                                 builder, moduleTranslation);
       }
@@ -3299,14 +3325,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
       findAllocaInsertPoint(builder, moduleTranslation);
 
   MapInfoData mapData;
-  collectMapDataFromMapVars(mapData, mapVars, moduleTranslation, dl, builder);
+  collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
+                                builder);
 
   llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
   auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
       -> llvm::OpenMPIRBuilder::MapInfosTy & {
     builder.restoreIP(codeGenIP);
-    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, {}, {},
-                true);
+    genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true);
     return combinedInfos;
   };
 
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index bf9fa183bfb802..458d2f28a78f8d 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -209,7 +209,8 @@ llvm.func @_QPopenmp_target_use_dev_ptr() {
   %0 = llvm.mlir.constant(1 : i64) : i64
   %a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
   %map1 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr)   map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
-  omp.target_data  map_entries(%map1 : !llvm.ptr) use_device_ptr(%a : !llvm.ptr)  {
+  %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr)   map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
+  omp.target_data  map_entries(%map1 : !llvm.ptr) use_device_ptr(%map2 : !llvm.ptr)  {
   ^bb0(%arg0: !llvm.ptr):
     %1 = llvm.mlir.constant(10 : i32) : i32
     %2 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
@@ -253,7 +254,8 @@ llvm.func @_QPopenmp_target_use_dev_addr() {
   %0 = llvm.mlir.constant(1 : i64) : i64
   %a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
   %map = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr)   map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
-  omp.target_data  map_entries(%map : !llvm.ptr) use_device_addr(%a : !llvm.ptr)  {
+  %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr)   map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
+  omp.target_data  map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr)  {
   ^bb0(%arg0: !llvm.ptr):
     %1 = llvm.mlir.constant(10 : i32) : i32
     %2 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
@@ -295,7 +297,8 @@ llvm.func @_QPopenmp_target_use_dev_addr_no_ptr() {
   %0 = llvm.mlir.constant(1 : i64) : i64
   %a = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr
   %map = omp.map.info var_ptr(%a : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
-  omp.target_data  map_entries(%map : !llvm.ptr) use_device_addr(%a : !llvm.ptr)  {
+  %map2 = omp.map.info var_ptr(%a : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+  omp.target_data  map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr)  {
   ^bb0(%arg0: !llvm.ptr):
     %1 = llvm.mlir.constant(10 : i32) : i32
     llvm.store %1, %arg0 : i32, !llvm.ptr
@@ -337,7 +340,8 @@ llvm.func @_QPopenmp_target_use_dev_addr_nomap() {
   %1 = llvm.mlir.constant(1 : i64) : i64
   %b = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
   %map = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr)   map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
-  omp.target_data  map_entries(%map : !llvm.ptr) use_device_addr(%a : !llvm.ptr)  {
+  %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+  omp.target_data  map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr)  {
   ^bb0(%arg0: !llvm.ptr):
     %2 = llvm.mlir.constant(10 : i32) : i32
     %3 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
@@ -394,7 +398,9 @@ llvm.func @_QPopenmp_target_use_dev_both() {
   %b = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
   %map = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
   %map1 = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
-  omp.target_data  map_entries(%map, %map1 : !llvm.ptr, !llvm.ptr) use_device_ptr(%a : !llvm.ptr) use_device_addr(%b : !llvm.ptr)  {
+  %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+  %map3 = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+  omp.target_data  map_entries(%map, %map1 : !llvm.ptr, !llvm.ptr) use_device_ptr(%map2 : !llvm.ptr) use_device_addr(%map3 : !llvm.ptr)  {
   ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
     %2 = llvm.mlir.constant(10 : i32) : i32
     %3 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
diff --git a/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir b/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
new file mode 100644
index 00000000000000..bcca2af985b322
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// This test checks that target code nested inside a target data region that
+// only has a use_device_ptr mapping is correctly generated on the device pass.
+
+// CHECK-NOT: call void @__tgt_target_data_begin_mapper
+// CHECK: store i32 999, ptr {{.*}}
+module attributes {omp.is_target_device = true } {
+  llvm.func @_QQmain() attributes {fir.bindc_name = "main"} {
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    %a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
+    %map = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+    omp.target_data use_device_ptr(%map : !llvm.ptr)  {
+    ^bb0(%arg0: !llvm.ptr):
+      %map1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+      omp.target map_entries(%map1 : !llvm.ptr){
+      ^bb0(%arg1: !llvm.ptr):
+        %1 = llvm.mlir.constant(999 : i32) : i32
+        %2 = llvm.load %arg1 : !llvm.ptr -> !llvm.ptr
+        llvm.store %1, %2 : i32, !llvm.ptr
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}



More information about the llvm-branch-commits mailing list