[llvm-branch-commits] [llvm] [mlir] [OpenMP]Update use_device_clause lowering (PR #101707)
Akash Banerjee via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Aug 21 10:48:56 PDT 2024
https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/101707
>From 547b339b175fa996eef8d45c5df8a73967ee94c2 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 2 Aug 2024 17:11:21 +0100
Subject: [PATCH 1/3] [OpenMP]Update use_device_clause lowering
This patch updates the use_device_ptr and use_device_addr clauses to use the mapInfoOps for lowering. This allows all the types that are handled by the map clauses, such as derived types, to also be supported by the use_device clauses.
This is patch 2/2 in a series of patches.
---
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 2 +-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 284 ++++++++++--------
mlir/test/Target/LLVMIR/omptarget-llvm.mlir | 16 +-
.../openmp-target-use-device-nested.mlir | 27 ++
4 files changed, 194 insertions(+), 135 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 83fec194d73904..f5d94069ad6f4c 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6357,7 +6357,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
// Disable TargetData CodeGen on Device pass.
if (Config.IsTargetDevice.value_or(false)) {
if (BodyGenCB)
- Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
+ Builder.restoreIP(BodyGenCB(CodeGenIP, BodyGenTy::NoPriv));
return Builder.saveIP();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 458d05d5059db7..78c460c50cbe5e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2110,6 +2110,8 @@ getRefPtrIfDeclareTarget(mlir::Value value,
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
llvm::SmallVector<bool, 4> IsDeclareTarget;
llvm::SmallVector<bool, 4> IsAMember;
+ // Identify if mapping was added by mapClause or use_device clauses.
+ llvm::SmallVector<bool, 4> IsAMapping;
llvm::SmallVector<mlir::Operation *, 4> MapClause;
llvm::SmallVector<llvm::Value *, 4> OriginalValue;
// Stripped off array/pointer to get the underlying
@@ -2193,62 +2195,125 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
-void collectMapDataFromMapVars(MapInfoData &mapData,
- llvm::SmallVectorImpl<Value> &mapVars,
- LLVM::ModuleTranslation &moduleTranslation,
- DataLayout &dl, llvm::IRBuilderBase &builder) {
+void collectMapDataFromMapOperands(
+ MapInfoData &mapData, llvm::SmallVectorImpl<Value> &mapVars,
+ LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
+ llvm::IRBuilderBase &builder,
+ const llvm::ArrayRef<Value> &useDevPtrOperands = {},
+ const llvm::ArrayRef<Value> &useDevAddrOperands = {}) {
+ // Process MapOperands
for (mlir::Value mapValue : mapVars) {
- if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
- mapValue.getDefiningOp())) {
- mlir::Value offloadPtr =
- mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
- mapData.OriginalValue.push_back(
- moduleTranslation.lookupValue(offloadPtr));
- mapData.Pointers.push_back(mapData.OriginalValue.back());
-
- if (llvm::Value *refPtr =
- getRefPtrIfDeclareTarget(offloadPtr,
- moduleTranslation)) { // declare target
- mapData.IsDeclareTarget.push_back(true);
- mapData.BasePointers.push_back(refPtr);
- } else { // regular mapped variable
- mapData.IsDeclareTarget.push_back(false);
- mapData.BasePointers.push_back(mapData.OriginalValue.back());
- }
+ auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp());
+ mlir::Value offloadPtr =
+ mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
+ mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
+ mapData.Pointers.push_back(mapData.OriginalValue.back());
+
+ if (llvm::Value *refPtr =
+ getRefPtrIfDeclareTarget(offloadPtr,
+ moduleTranslation)) { // declare target
+ mapData.IsDeclareTarget.push_back(true);
+ mapData.BasePointers.push_back(refPtr);
+ } else { // regular mapped variable
+ mapData.IsDeclareTarget.push_back(false);
+ mapData.BasePointers.push_back(mapData.OriginalValue.back());
+ }
- mapData.BaseType.push_back(
- moduleTranslation.convertType(mapOp.getVarType()));
- mapData.Sizes.push_back(
- getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
- mapData.BaseType.back(), builder, moduleTranslation));
- mapData.MapClause.push_back(mapOp.getOperation());
- mapData.Types.push_back(
- llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
- mapData.Names.push_back(LLVM::createMappingInformation(
- mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
- mapData.DevicePointers.push_back(
- llvm::OpenMPIRBuilder::DeviceInfoTy::None);
-
- // Check if this is a member mapping and correctly assign that it is, if
- // it is a member of a larger object.
- // TODO: Need better handling of members, and distinguishing of members
- // that are implicitly allocated on device vs explicitly passed in as
- // arguments.
- // TODO: May require some further additions to support nested record
- // types, i.e. member maps that can have member maps.
- mapData.IsAMember.push_back(false);
- for (mlir::Value mapValue : mapVars) {
- if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
- mapValue.getDefiningOp())) {
- for (auto member : map.getMembers()) {
- if (member == mapOp) {
- mapData.IsAMember.back() = true;
- }
+ mapData.BaseType.push_back(
+ moduleTranslation.convertType(mapOp.getVarType()));
+ mapData.Sizes.push_back(
+ getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
+ mapData.BaseType.back(), builder, moduleTranslation));
+ mapData.MapClause.push_back(mapOp.getOperation());
+ mapData.Types.push_back(
+ llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
+ mapData.Names.push_back(LLVM::createMappingInformation(
+ mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
+ mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+ mapData.IsAMapping.push_back(true);
+
+ // Check if this is a member mapping and correctly assign that it is, if
+ // it is a member of a larger object.
+ // TODO: Need better handling of members, and distinguishing of members
+ // that are implicitly allocated on device vs explicitly passed in as
+ // arguments.
+ // TODO: May require some further additions to support nested record
+ // types, i.e. member maps that can have member maps.
+ mapData.IsAMember.push_back(false);
+ for (mlir::Value mapValue : mapVars) {
+ if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
+ mapValue.getDefiningOp())) {
+ for (auto member : map.getMembers()) {
+ if (member == mapOp) {
+ mapData.IsAMember.back() = true;
}
}
}
}
}
+
+ auto findMapInfo = [&mapData](llvm::Value *val,
+ llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
+ unsigned index = 0;
+ bool found = false;
+ for (llvm::Value *basePtr : mapData.OriginalValue) {
+ if (basePtr == val && mapData.IsAMapping[index]) {
+ found = true;
+ mapData.Types[index] |=
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+ mapData.DevicePointers[index] = devInfoTy;
+ }
+ index++;
+ }
+ return found;
+ };
+
+ // Process useDevPtr(Addr)Operands
+ auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
+ llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
+ for (mlir::Value mapValue : useDevOperands) {
+ auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp());
+ mlir::Value offloadPtr =
+ mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
+ llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
+
+ // Check if map info is already present for this entry.
+ if (!findMapInfo(origValue, devInfoTy)) {
+ mapData.OriginalValue.push_back(origValue);
+ mapData.Pointers.push_back(mapData.OriginalValue.back());
+ mapData.IsDeclareTarget.push_back(false);
+ mapData.BasePointers.push_back(mapData.OriginalValue.back());
+ mapData.BaseType.push_back(
+ moduleTranslation.convertType(mapOp.getVarType()));
+ mapData.Sizes.push_back(builder.getInt64(0));
+ mapData.MapClause.push_back(mapOp.getOperation());
+ mapData.Types.push_back(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
+ mapData.Names.push_back(LLVM::createMappingInformation(
+ mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
+ mapData.DevicePointers.push_back(devInfoTy);
+ mapData.IsAMapping.push_back(true);
+
+ // Check if this is a member mapping and correctly assign that it is,
+ // if it is a member of a larger object.
+ // TODO: Need better handling of members, and distinguishing of
+ // members that are implicitly allocated on device vs explicitly
+ // passed in as arguments.
+ // TODO: May require some further additions to support nested record
+ // types, i.e. member maps that can have member maps.
+ mapData.IsAMember.push_back(false);
+ for (mlir::Value mapValue : useDevOperands)
+ if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
+ mapValue.getDefiningOp()))
+ for (auto member : map.getMembers())
+ if (member == mapOp)
+ mapData.IsAMember.back() = true;
+ }
+ }
+ };
+
+ addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
+ addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
}
static int getMapDataMemberIdx(MapInfoData &mapData,
@@ -2426,7 +2491,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
combinedInfo.DevicePointers.emplace_back(
- llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+ mapData.DevicePointers[mapDataIndex]);
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
@@ -2553,7 +2618,7 @@ static void processMapMembersWithParent(
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
- llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+ mapData.DevicePointers[memberDataIdx]);
combinedInfo.Names.emplace_back(
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
@@ -2714,10 +2779,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
DataLayout &dl,
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
- MapInfoData &mapData,
- const SmallVector<Value> &useDevicePtrVars = {},
- const SmallVector<Value> &useDeviceAddrVars = {},
- bool isTargetParams = false) {
+ MapInfoData &mapData, bool isTargetParams = false) {
// We wish to modify some of the methods in which arguments are
// passed based on their capture type by the target region, this can
// involve generating new loads and stores, which changes the
@@ -2734,15 +2796,6 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
- auto fail = [&combinedInfo]() -> void {
- combinedInfo.BasePointers.clear();
- combinedInfo.Pointers.clear();
- combinedInfo.DevicePointers.clear();
- combinedInfo.Sizes.clear();
- combinedInfo.Types.clear();
- combinedInfo.Names.clear();
- };
-
// We operate under the assumption that all vectors that are
// required in MapInfoData are of equal lengths (either filled with
// default constructed data or appropiate information) so we can
@@ -2763,46 +2816,6 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
processIndividualMap(mapData, i, combinedInfo, isTargetParams);
}
-
- auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {
- index = 0;
- for (llvm::Value *basePtr : combinedInfo.BasePointers) {
- if (basePtr == val)
- return true;
- index++;
- }
- return false;
- };
-
- auto addDevInfos = [&, fail](auto useDeviceVars, auto devOpType) -> void {
- for (const auto &useDeviceVar : useDeviceVars) {
- // TODO: Only LLVMPointerTypes are handled.
- if (!isa<LLVM::LLVMPointerType>(useDeviceVar.getType()))
- return fail();
-
- llvm::Value *mapOpValue = moduleTranslation.lookupValue(useDeviceVar);
-
- // Check if map info is already present for this entry.
- unsigned infoIndex;
- if (findMapInfo(mapOpValue, infoIndex)) {
- combinedInfo.Types[infoIndex] |=
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
- combinedInfo.DevicePointers[infoIndex] = devOpType;
- } else {
- combinedInfo.BasePointers.emplace_back(mapOpValue);
- combinedInfo.Pointers.emplace_back(mapOpValue);
- combinedInfo.DevicePointers.emplace_back(devOpType);
- combinedInfo.Names.emplace_back(
- LLVM::createMappingInformation(useDeviceVar.getLoc(), *ompBuilder));
- combinedInfo.Types.emplace_back(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
- combinedInfo.Sizes.emplace_back(builder.getInt64(0));
- }
- }
- };
-
- addDevInfos(useDevicePtrVars, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
- addDevInfos(useDeviceAddrVars, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
}
static LogicalResult
@@ -2899,19 +2912,15 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
MapInfoData mapData;
- collectMapDataFromMapVars(mapData, mapVars, moduleTranslation, DL, builder);
+ collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, DL,
+ builder, useDevicePtrVars, useDeviceAddrVars);
// Fill up the arrays with all the mapped variables.
llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
auto genMapInfoCB =
[&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
builder.restoreIP(codeGenIP);
- if (auto dataOp = dyn_cast<omp::TargetDataOp>(op)) {
- genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
- useDevicePtrVars, useDeviceAddrVars);
- } else {
- genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
- }
+ genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
return combinedInfo;
};
@@ -2930,21 +2939,23 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (!info.DevicePtrInfoMap.empty()) {
builder.restoreIP(codeGenIP);
unsigned argIndex = 0;
- for (auto &devPtrOp : useDevicePtrVars) {
- llvm::Value *mapOpValue = moduleTranslation.lookupValue(devPtrOp);
- const auto &arg = region.front().getArgument(argIndex);
- moduleTranslation.mapValue(arg,
- info.DevicePtrInfoMap[mapOpValue].second);
- argIndex++;
- }
-
- for (auto &devAddrOp : useDeviceAddrVars) {
- llvm::Value *mapOpValue = moduleTranslation.lookupValue(devAddrOp);
- const auto &arg = region.front().getArgument(argIndex);
- auto *LI = builder.CreateLoad(
- builder.getPtrTy(), info.DevicePtrInfoMap[mapOpValue].second);
- moduleTranslation.mapValue(arg, LI);
- argIndex++;
+ for (size_t i = 0; i < combinedInfo.BasePointers.size(); ++i) {
+ if (combinedInfo.DevicePointers[i] ==
+ llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) {
+ const auto &arg = region.front().getArgument(argIndex);
+ moduleTranslation.mapValue(
+ arg,
+ info.DevicePtrInfoMap[combinedInfo.BasePointers[i]].second);
+ argIndex++;
+ } else if (combinedInfo.DevicePointers[i] ==
+ llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
+ const auto &arg = region.front().getArgument(argIndex);
+ auto *loadInst = builder.CreateLoad(
+ builder.getPtrTy(),
+ info.DevicePtrInfoMap[combinedInfo.BasePointers[i]].second);
+ moduleTranslation.mapValue(arg, loadInst);
+ argIndex++;
+ }
}
bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
@@ -2957,6 +2968,21 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
// If device info is available then region has already been generated
if (info.DevicePtrInfoMap.empty()) {
builder.restoreIP(codeGenIP);
+ // For device pass, if use_device_ptr(addr) mappings were present,
+ // we need to link them here before codegen.
+ if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
+ unsigned argIndex = 0;
+ for (size_t i = 0; i < mapData.BasePointers.size(); ++i) {
+ if (mapData.DevicePointers[i] ==
+ llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
+ mapData.DevicePointers[i] ==
+ llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
+ const auto &arg = region.front().getArgument(argIndex);
+ moduleTranslation.mapValue(arg, mapData.BasePointers[i]);
+ argIndex++;
+ }
+ }
+ }
bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
builder, moduleTranslation);
}
@@ -3299,14 +3325,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
findAllocaInsertPoint(builder, moduleTranslation);
MapInfoData mapData;
- collectMapDataFromMapVars(mapData, mapVars, moduleTranslation, dl, builder);
+ collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
+ builder);
llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
-> llvm::OpenMPIRBuilder::MapInfosTy & {
builder.restoreIP(codeGenIP);
- genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, {}, {},
- true);
+ genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true);
return combinedInfos;
};
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index bf9fa183bfb802..458d2f28a78f8d 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -209,7 +209,8 @@ llvm.func @_QPopenmp_target_use_dev_ptr() {
%0 = llvm.mlir.constant(1 : i64) : i64
%a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
%map1 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data map_entries(%map1 : !llvm.ptr) use_device_ptr(%a : !llvm.ptr) {
+ %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
+ omp.target_data map_entries(%map1 : !llvm.ptr) use_device_ptr(%map2 : !llvm.ptr) {
^bb0(%arg0: !llvm.ptr):
%1 = llvm.mlir.constant(10 : i32) : i32
%2 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
@@ -253,7 +254,8 @@ llvm.func @_QPopenmp_target_use_dev_addr() {
%0 = llvm.mlir.constant(1 : i64) : i64
%a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
%map = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%a : !llvm.ptr) {
+ %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
+ omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr) {
^bb0(%arg0: !llvm.ptr):
%1 = llvm.mlir.constant(10 : i32) : i32
%2 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
@@ -295,7 +297,8 @@ llvm.func @_QPopenmp_target_use_dev_addr_no_ptr() {
%0 = llvm.mlir.constant(1 : i64) : i64
%a = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr
%map = omp.map.info var_ptr(%a : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%a : !llvm.ptr) {
+ %map2 = omp.map.info var_ptr(%a : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+ omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr) {
^bb0(%arg0: !llvm.ptr):
%1 = llvm.mlir.constant(10 : i32) : i32
llvm.store %1, %arg0 : i32, !llvm.ptr
@@ -337,7 +340,8 @@ llvm.func @_QPopenmp_target_use_dev_addr_nomap() {
%1 = llvm.mlir.constant(1 : i64) : i64
%b = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
%map = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%a : !llvm.ptr) {
+ %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+ omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr) {
^bb0(%arg0: !llvm.ptr):
%2 = llvm.mlir.constant(10 : i32) : i32
%3 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
@@ -394,7 +398,9 @@ llvm.func @_QPopenmp_target_use_dev_both() {
%b = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
%map = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
%map1 = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data map_entries(%map, %map1 : !llvm.ptr, !llvm.ptr) use_device_ptr(%a : !llvm.ptr) use_device_addr(%b : !llvm.ptr) {
+ %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+ %map3 = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+ omp.target_data map_entries(%map, %map1 : !llvm.ptr, !llvm.ptr) use_device_ptr(%map2 : !llvm.ptr) use_device_addr(%map3 : !llvm.ptr) {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
%2 = llvm.mlir.constant(10 : i32) : i32
%3 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
diff --git a/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir b/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
new file mode 100644
index 00000000000000..bcca2af985b322
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// This test checks that target code nested inside a target data region that
+// has only a use_device_ptr mapping correctly generates code on the device pass.
+
+// CHECK-NOT: call void @__tgt_target_data_begin_mapper
+// CHECK: store i32 999, ptr {{.*}}
+module attributes {omp.is_target_device = true } {
+ llvm.func @_QQmain() attributes {fir.bindc_name = "main"} {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
+ %map = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+ omp.target_data use_device_ptr(%map : !llvm.ptr) {
+ ^bb0(%arg0: !llvm.ptr):
+ %map1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+ omp.target map_entries(%map1 : !llvm.ptr){
+ ^bb0(%arg1: !llvm.ptr):
+ %1 = llvm.mlir.constant(999 : i32) : i32
+ %2 = llvm.load %arg1 : !llvm.ptr -> !llvm.ptr
+ llvm.store %1, %2 : i32, !llvm.ptr
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
>From 9f346458b8ea8915bef91faf93477cfc50b3fe94 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 20 Aug 2024 15:12:50 +0100
Subject: [PATCH 2/3] Address reviewer comments.
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 96 +++++++++----------
1 file changed, 46 insertions(+), 50 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 78c460c50cbe5e..e56dfdeb265117 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2101,6 +2101,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
return nullptr;
}
+namespace {
// A small helper structure to contain data gathered
// for map lowering and coalese it into one area and
// avoiding extra computations such as searches in the
@@ -2129,6 +2130,7 @@ struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
}
};
+} // namespace
uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
@@ -2195,16 +2197,15 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
-void collectMapDataFromMapOperands(
- MapInfoData &mapData, llvm::SmallVectorImpl<Value> &mapVars,
+static void collectMapDataFromMapOperands(
+ MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
- llvm::IRBuilderBase &builder,
- const llvm::ArrayRef<Value> &useDevPtrOperands = {},
- const llvm::ArrayRef<Value> &useDevAddrOperands = {}) {
+ llvm::IRBuilderBase &builder, const ArrayRef<Value> &useDevPtrOperands = {},
+ const ArrayRef<Value> &useDevAddrOperands = {}) {
// Process MapOperands
- for (mlir::Value mapValue : mapVars) {
- auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp());
- mlir::Value offloadPtr =
+ for (Value mapValue : mapVars) {
+ auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
+ Value offloadPtr =
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
mapData.Pointers.push_back(mapData.OriginalValue.back());
@@ -2240,9 +2241,9 @@ void collectMapDataFromMapOperands(
// TODO: May require some further additions to support nested record
// types, i.e. member maps that can have member maps.
mapData.IsAMember.push_back(false);
- for (mlir::Value mapValue : mapVars) {
- if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
- mapValue.getDefiningOp())) {
+ for (Value mapValue : mapVars) {
+ if (auto map =
+ dyn_cast_if_present<omp::MapInfoOp>(mapValue.getDefiningOp())) {
for (auto member : map.getMembers()) {
if (member == mapOp) {
mapData.IsAMember.back() = true;
@@ -2271,9 +2272,9 @@ void collectMapDataFromMapOperands(
// Process useDevPtr(Addr)Operands
auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
- for (mlir::Value mapValue : useDevOperands) {
- auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp());
- mlir::Value offloadPtr =
+ for (Value mapValue : useDevOperands) {
+ auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
+ Value offloadPtr =
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
@@ -2302,9 +2303,9 @@ void collectMapDataFromMapOperands(
// TODO: May require some further additions to support nested record
// types, i.e. member maps that can have member maps.
mapData.IsAMember.push_back(false);
- for (mlir::Value mapValue : useDevOperands)
- if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
- mapValue.getDefiningOp()))
+ for (Value mapValue : useDevOperands)
+ if (auto map =
+ dyn_cast_if_present<omp::MapInfoOp>(mapValue.getDefiningOp()))
for (auto member : map.getMembers())
if (member == mapOp)
mapData.IsAMember.back() = true;
@@ -2316,22 +2317,21 @@ void collectMapDataFromMapOperands(
addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
}
-static int getMapDataMemberIdx(MapInfoData &mapData,
- mlir::omp::MapInfoOp memberOp) {
+static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
auto *res = llvm::find(mapData.MapClause, memberOp);
assert(res != mapData.MapClause.end() &&
"MapInfoOp for member not found in MapData, cannot return index");
return std::distance(mapData.MapClause.begin(), res);
}
-static mlir::omp::MapInfoOp
-getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
- mlir::DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();
+static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
+ bool first) {
+ DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();
// Only 1 member has been mapped, we can return it.
if (indexAttr.size() == 1)
- if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
- mapInfo.getMembers()[0].getDefiningOp()))
+ if (auto mapOp =
+ dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()))
return mapOp;
llvm::ArrayRef<int64_t> shape = indexAttr.getShapedType().getShape();
@@ -2368,7 +2368,7 @@ getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
return false;
});
- return llvm::cast<mlir::omp::MapInfoOp>(
+ return llvm::cast<omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
}
@@ -2394,7 +2394,7 @@ getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
std::vector<llvm::Value *>
calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
llvm::IRBuilderBase &builder, bool isArrayTy,
- mlir::OperandRange bounds) {
+ OperandRange bounds) {
std::vector<llvm::Value *> idx;
// There's no bounds to calculate an offset from, we can safely
// ignore and return no indices.
@@ -2408,7 +2408,7 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
if (isArrayTy) {
idx.push_back(builder.getInt64(0));
for (int i = bounds.size() - 1; i >= 0; --i) {
- if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
+ if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
bounds[i].getDefiningOp())) {
idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
}
@@ -2434,7 +2434,7 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
// (extent/size of current) 100 for 1000 for each index increment
std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
for (size_t i = 1; i < bounds.size(); ++i) {
- if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
+ if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
bounds[i].getDefiningOp())) {
dimensionIndexSizeOffset.push_back(builder.CreateMul(
moduleTranslation.lookupValue(boundOp.getExtent()),
@@ -2447,7 +2447,7 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
// have calculated in the previous and accumulate the results to get
// our final resulting offset.
for (int i = bounds.size() - 1; i >= 0; --i) {
- if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
+ if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
bounds[i].getDefiningOp())) {
if (idx.empty())
idx.emplace_back(builder.CreateMul(
@@ -2504,7 +2504,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
// data by the descriptor (which itself, is a structure containing
// runtime information on the dynamically allocated data).
auto parentClause =
- llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
+ llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
llvm::Value *lowAddr, *highAddr;
if (!parentClause.getPartialMap()) {
@@ -2516,8 +2516,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
builder.getPtrTy());
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
} else {
- auto mapOp =
- mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
+ auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
int firstMemberIdx = getMapDataMemberIdx(
mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
@@ -2575,7 +2574,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
// There may be a better way to verify this, but unfortunately with
// opaque pointers we lose the ability to easily check if something is
// a pointer whilst maintaining access to the underlying type.
-static bool checkIfPointerMap(mlir::omp::MapInfoOp mapOp) {
+static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
// If we have a varPtrPtr field assigned then the underlying type is a pointer
if (mapOp.getVarPtrPtr())
return true;
@@ -2597,11 +2596,11 @@ static void processMapMembersWithParent(
uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
auto parentClause =
- llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
+ llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
for (auto mappedMembers : parentClause.getMembers()) {
auto memberClause =
- llvm::cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp());
+ llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
assert(memberDataIdx >= 0 && "could not find mapped member of structure");
@@ -2635,8 +2634,7 @@ processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
// OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
// marked with OMP_MAP_PTR_AND_OBJ instead.
auto mapFlag = mapData.Types[mapDataIdx];
- auto mapInfoOp =
- llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
+ auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
bool isPtrTy = checkIfPointerMap(mapInfoOp);
if (isPtrTy)
@@ -2646,7 +2644,7 @@ processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
if (mapInfoOp.getMapCaptureType().value() ==
- mlir::omp::VariableCaptureKind::ByCopy &&
+ omp::VariableCaptureKind::ByCopy &&
!isPtrTy)
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
@@ -2672,13 +2670,13 @@ static void processMapWithMembersOf(
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
uint64_t mapDataIndex, bool isTargetParams) {
auto parentClause =
- llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
+ llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
// If we have a partial map (no parent referenced in the map clauses of the
// directive, only members) and only a single member, we do not need to bind
// the map of the member to the parent, we can pass the member separately.
if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
- auto memberClause = llvm::cast<mlir::omp::MapInfoOp>(
+ auto memberClause = llvm::cast<omp::MapInfoOp>(
parentClause.getMembers()[0].getDefiningOp());
int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
// Note: Clang treats arrays with explicit bounds that fall into this
@@ -2715,11 +2713,9 @@ createAlteredByCaptureMap(MapInfoData &mapData,
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
// if it's declare target, skip it, it's handled separately.
if (!mapData.IsDeclareTarget[i]) {
- auto mapOp =
- mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
- mlir::omp::VariableCaptureKind captureKind =
- mapOp.getMapCaptureType().value_or(
- mlir::omp::VariableCaptureKind::ByRef);
+ auto mapOp = dyn_cast_if_present<omp::MapInfoOp>(mapData.MapClause[i]);
+ omp::VariableCaptureKind captureKind =
+ mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
bool isPtrTy = checkIfPointerMap(mapOp);
// Currently handles array sectioning lowerbound case, but more
@@ -2730,7 +2726,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
// function mimics some of the logic from Clang that we require for
// kernel argument passing from host -> device.
switch (captureKind) {
- case mlir::omp::VariableCaptureKind::ByRef: {
+ case omp::VariableCaptureKind::ByRef: {
llvm::Value *newV = mapData.Pointers[i];
std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
@@ -2743,7 +2739,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
"array_offset");
mapData.Pointers[i] = newV;
} break;
- case mlir::omp::VariableCaptureKind::ByCopy: {
+ case omp::VariableCaptureKind::ByCopy: {
llvm::Type *type = mapData.BaseType[i];
llvm::Value *newV;
if (mapData.Pointers[i]->getType()->isPointerTy())
@@ -2765,8 +2761,8 @@ createAlteredByCaptureMap(MapInfoData &mapData,
mapData.Pointers[i] = newV;
mapData.BasePointers[i] = newV;
} break;
- case mlir::omp::VariableCaptureKind::This:
- case mlir::omp::VariableCaptureKind::VLAType:
+ case omp::VariableCaptureKind::This:
+ case omp::VariableCaptureKind::VLAType:
mapData.MapClause[i]->emitOpError("Unhandled capture kind");
break;
}
@@ -2807,7 +2803,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
if (mapData.IsAMember[i])
continue;
- auto mapInfoOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
+ auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
if (!mapInfoOp.getMembers().empty()) {
processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
combinedInfo, mapData, i, isTargetParams);
>From 9cb50454c363fd45ba3be6f03467ae9312338941 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Wed, 21 Aug 2024 18:48:05 +0100
Subject: [PATCH 3/3] Addressed reviewer comments.
---
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 4 +-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 73 ++++++++-----------
.../openmp-target-use-device-nested.mlir | 16 +++-
3 files changed, 47 insertions(+), 46 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index f5d94069ad6f4c..6f8d3e53d5eed4 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6354,14 +6354,14 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
if (!updateToLocation(Loc))
return InsertPointTy();
+ Builder.restoreIP(CodeGenIP);
// Disable TargetData CodeGen on Device pass.
if (Config.IsTargetDevice.value_or(false)) {
if (BodyGenCB)
- Builder.restoreIP(BodyGenCB(CodeGenIP, BodyGenTy::NoPriv));
+ Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
return Builder.saveIP();
}
- Builder.restoreIP(CodeGenIP);
bool IsStandAlone = !BodyGenCB;
MapInfosTy *MapInfo;
// Generate the code for the opening of the data environment. Capture all the
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e56dfdeb265117..6a8b4ac3505cff 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2242,14 +2242,10 @@ static void collectMapDataFromMapOperands(
// types, i.e. member maps that can have member maps.
mapData.IsAMember.push_back(false);
for (Value mapValue : mapVars) {
- if (auto map =
- dyn_cast_if_present<omp::MapInfoOp>(mapValue.getDefiningOp())) {
- for (auto member : map.getMembers()) {
- if (member == mapOp) {
- mapData.IsAMember.back() = true;
- }
- }
- }
+ auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
+ for (auto member : map.getMembers())
+ if (member == mapOp)
+ mapData.IsAMember.back() = true;
}
}
@@ -2303,12 +2299,12 @@ static void collectMapDataFromMapOperands(
// TODO: May require some further additions to support nested record
// types, i.e. member maps that can have member maps.
mapData.IsAMember.push_back(false);
- for (Value mapValue : useDevOperands)
- if (auto map =
- dyn_cast_if_present<omp::MapInfoOp>(mapValue.getDefiningOp()))
- for (auto member : map.getMembers())
- if (member == mapOp)
- mapData.IsAMember.back() = true;
+ for (Value mapValue : useDevOperands) {
+ auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
+ for (auto member : map.getMembers())
+ if (member == mapOp)
+ mapData.IsAMember.back() = true;
+ }
}
}
};
@@ -2713,7 +2709,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
// if it's declare target, skip it, it's handled separately.
if (!mapData.IsDeclareTarget[i]) {
- auto mapOp = dyn_cast_if_present<omp::MapInfoOp>(mapData.MapClause[i]);
+ auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
omp::VariableCaptureKind captureKind =
mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
bool isPtrTy = checkIfPointerMap(mapOp);
@@ -2935,20 +2931,18 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (!info.DevicePtrInfoMap.empty()) {
builder.restoreIP(codeGenIP);
unsigned argIndex = 0;
- for (size_t i = 0; i < combinedInfo.BasePointers.size(); ++i) {
- if (combinedInfo.DevicePointers[i] ==
- llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) {
+ for (auto [basePointer, devicePointer] : llvm::zip_equal(
+ combinedInfo.BasePointers, combinedInfo.DevicePointers)) {
+ if (devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) {
const auto &arg = region.front().getArgument(argIndex);
moduleTranslation.mapValue(
- arg,
- info.DevicePtrInfoMap[combinedInfo.BasePointers[i]].second);
+ arg, info.DevicePtrInfoMap[basePointer].second);
argIndex++;
- } else if (combinedInfo.DevicePointers[i] ==
+ } else if (devicePointer ==
llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
const auto &arg = region.front().getArgument(argIndex);
auto *loadInst = builder.CreateLoad(
- builder.getPtrTy(),
- info.DevicePtrInfoMap[combinedInfo.BasePointers[i]].second);
+ builder.getPtrTy(), info.DevicePtrInfoMap[basePointer].second);
moduleTranslation.mapValue(arg, loadInst);
argIndex++;
}
@@ -2968,13 +2962,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
// we need to link them here before codegen.
if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
unsigned argIndex = 0;
- for (size_t i = 0; i < mapData.BasePointers.size(); ++i) {
- if (mapData.DevicePointers[i] ==
- llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
- mapData.DevicePointers[i] ==
- llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
+ for (auto [basePointer, devicePointer] :
+ llvm::zip_equal(mapData.BasePointers, mapData.DevicePointers)) {
+ if (devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
+ devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
const auto &arg = region.front().getArgument(argIndex);
- moduleTranslation.mapValue(arg, mapData.BasePointers[i]);
+ moduleTranslation.mapValue(arg, basePointer);
argIndex++;
}
}
@@ -3198,17 +3191,14 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
llvm::IRBuilderBase::InsertPoint codeGenIP) {
builder.restoreIP(allocaIP);
- mlir::omp::VariableCaptureKind capture =
- mlir::omp::VariableCaptureKind::ByRef;
+ omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
// Find the associated MapInfoData entry for the current input
for (size_t i = 0; i < mapData.MapClause.size(); ++i)
if (mapData.OriginalValue[i] == input) {
- if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
- mapData.MapClause[i])) {
- capture = mapOp.getMapCaptureType().value_or(
- mlir::omp::VariableCaptureKind::ByRef);
- }
+ auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
+ capture =
+ mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
break;
}
@@ -3229,18 +3219,18 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
builder.restoreIP(codeGenIP);
switch (capture) {
- case mlir::omp::VariableCaptureKind::ByCopy: {
+ case omp::VariableCaptureKind::ByCopy: {
retVal = v;
break;
}
- case mlir::omp::VariableCaptureKind::ByRef: {
+ case omp::VariableCaptureKind::ByRef: {
retVal = builder.CreateAlignedLoad(
v->getType(), v,
ompBuilder.M.getDataLayout().getPrefTypeAlign(v->getType()));
break;
}
- case mlir::omp::VariableCaptureKind::This:
- case mlir::omp::VariableCaptureKind::VLAType:
+ case omp::VariableCaptureKind::This:
+ case omp::VariableCaptureKind::VLAType:
assert(false && "Currently unsupported capture kind");
break;
}
@@ -3292,8 +3282,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
builder.restoreIP(codeGenIP);
unsigned argIndex = 0;
for (auto &mapOp : mapVars) {
- auto mapInfoOp =
- mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
+ auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
llvm::Value *mapOpValue =
moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
const auto &arg = targetRegion.front().getArgument(argIndex);
diff --git a/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir b/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
index bcca2af985b322..f094a46581dee0 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
@@ -3,8 +3,20 @@
// This tests check that target code nested inside a target data region which
// has only use_device_ptr mapping corectly generates code on the device pass.
-// CHECK-NOT: call void @__tgt_target_data_begin_mapper
-// CHECK: store i32 999, ptr {{.*}}
+// CHECK: define weak_odr protected void @__omp_offloading{{.*}}main_
+// CHECK-NEXT: entry:
+// CHECK-NEXT: %[[VAL_3:.*]] = alloca ptr, align 8
+// CHECK-NEXT: store ptr %[[VAL_4:.*]], ptr %[[VAL_3]], align 8
+// CHECK-NEXT: %[[VAL_5:.*]] = call i32 @__kmpc_target_init(ptr @__omp_offloading_{{.*}}_kernel_environment, ptr %[[VAL_6:.*]])
+// CHECK-NEXT: %[[VAL_7:.*]] = icmp eq i32 %[[VAL_5]], -1
+// CHECK-NEXT: br i1 %[[VAL_7]], label %[[VAL_8:.*]], label %[[VAL_9:.*]]
+// CHECK: user_code.entry: ; preds = %[[VAL_10:.*]]
+// CHECK-NEXT: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_3]], align 8
+// CHECK-NEXT: br label %[[VAL_12:.*]]
+// CHECK: omp.target: ; preds = %[[VAL_8]]
+// CHECK-NEXT: %[[VAL_13:.*]] = load ptr, ptr %[[VAL_11]], align 8
+// CHECK-NEXT: store i32 999, ptr %[[VAL_13]], align 4
+// CHECK-NEXT: br label %[[VAL_14:.*]]
module attributes {omp.is_target_device = true } {
llvm.func @_QQmain() attributes {fir.bindc_name = "main"} {
%0 = llvm.mlir.constant(1 : i64) : i64
More information about the llvm-branch-commits
mailing list