[llvm-branch-commits] [clang] [flang] [llvm] [mlir] [MLIR][OpenMP] Add LLVM translation support for OpenMP UserDefinedMappers (PR #124746)
Akash Banerjee via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 31 02:54:05 PST 2025
https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/124746
>From 431c404dc125aa6b27f32b6019baebf603111f51 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 28 Jan 2025 15:45:55 +0000
Subject: [PATCH 1/5] Add description for mapper_id. Add verifier check for
valid mapper_id.
---
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 2 ++
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 6 ++++++
mlir/test/Dialect/OpenMP/invalid.mlir | 10 ++++++++++
3 files changed, 18 insertions(+)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 7cc2c8fa8ce1e1..326b5d0122f3ba 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1054,6 +1054,8 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
- 'map_type': OpenMP map type for this map capture, for example: from, to and
always. It's a bitfield composed of the OpenMP runtime flags stored in
OpenMPOffloadMappingFlags.
+ - 'mapper_id': OpenMP mapper map type modifier for this map capture. It names the
+ user-defined mapper (a `DeclareMapperOp` symbol) to be used when mapping this variable.
- 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla
this can affect how the variable is lowered.
- `name`: Holds the name of variable as specified in user clause (including bounds).
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 92233654ba1ddc..6fc1495938fe55 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1592,6 +1592,12 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
}
+
+ if (mapInfoOp.getMapperId() &&
+ !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
+ mapInfoOp, mapInfoOp.getMapperIdAttr())) {
+ return emitError(op->getLoc(), "invalid mapper id");
+ }
} else if (!isa<DeclareMapperInfoOp>(op)) {
emitError(op->getLoc(), "map argument is not a map entry operation");
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 1fbb4c93e855b9..93db9eb08aac98 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2661,3 +2661,13 @@ func.func @missing_workshare(%idx : index) {
}
return
}
+
+// -----
+llvm.func @invalid_mapper(%0 : !llvm.ptr) {
+ %1 = omp.map.info var_ptr(%0 : !llvm.ptr, !llvm.struct<"my_type", (i32)>) mapper(@my_mapper) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
+ // expected-error @below {{invalid mapper id}}
+ omp.target_data map_entries(%1 : !llvm.ptr) {
+ omp.terminator
+ }
+ llvm.return
+}
>From 728a8321b572553ac3fa853f785ab3e58d93e1f8 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 28 Jan 2025 15:28:29 +0000
Subject: [PATCH 2/5] Split test into two separate directives.
---
flang/test/Lower/OpenMP/map-mapper.f90 | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/flang/test/Lower/OpenMP/map-mapper.f90 b/flang/test/Lower/OpenMP/map-mapper.f90
index 856fff834643e4..0d8fe7344bfab5 100644
--- a/flang/test/Lower/OpenMP/map-mapper.f90
+++ b/flang/test/Lower/OpenMP/map-mapper.f90
@@ -13,11 +13,18 @@ program p
type(t1) :: a, b
!CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : {{.*}}, {{.*}}) mapper(@_QQFxx) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"}
+ !CHECK: omp.target map_entries(%[[MAP_A]] -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}) {
+ !$omp target map(mapper(xx) : a)
+ do i = 1, n
+ a%x(i) = i
+ end do
+ !$omp end target
+
!CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%{{.*}} : {{.*}}, {{.*}}) mapper(@_QQFt1.default) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "b"}
- !CHECK: omp.target map_entries(%[[MAP_A]] -> %{{.*}}, %[[MAP_B]] -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}, {{.*}}, {{.*}}) {
- !$omp target map(mapper(xx) : a) map(mapper(default) : b)
+ !CHECK: omp.target map_entries(%[[MAP_B]] -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}) {
+ !$omp target map(mapper(default) : b)
do i = 1, n
- b%x(i) = a%x(i)
+ b%x(i) = i
end do
!$omp end target
end program p
>From e020b174360a645f64e2a6562eece48ce8a97482 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 28 Jan 2025 13:38:13 +0000
Subject: [PATCH 3/5] [MLIR][OpenMP] Add LLVM translation support for OpenMP
UserDefinedMappers
This patch adds OpenMPToLLVMIRTranslation support for the OpenMP Declare Mapper directive.
Since both MLIR and Clang now support custom mappers, I've also made the relevant parameters required instead of optional.
Depends on #121005
---
clang/lib/CodeGen/CGOpenMPRuntime.cpp | 20 +-
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 49 +++--
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 78 +++----
.../Frontend/OpenMPIRBuilderTest.cpp | 56 +++--
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 198 ++++++++++++++----
mlir/test/Target/LLVMIR/omptarget-llvm.mlir | 117 +++++++++++
.../fortran/target-custom-mapper.f90 | 46 ++++
7 files changed, 443 insertions(+), 121 deletions(-)
create mode 100644 offload/test/offloading/fortran/target-custom-mapper.f90
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 30c3834de139c3..0a13581dcb1700 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -32,10 +32,12 @@
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cstdint>
@@ -8888,8 +8890,8 @@ static void emitOffloadingArraysAndArgs(
return MFunc;
};
OMPBuilder.emitOffloadingArraysAndArgs(
- AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous,
- ForEndCall, DeviceAddrCB, CustomMapperCB);
+ AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB,
+ IsNonContiguous, ForEndCall, DeviceAddrCB);
}
/// Check for inner distribute directive.
@@ -9098,9 +9100,10 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out);
std::string Name = getName({"omp_mapper", TyStr, D->getName()});
- auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB,
- ElemTy, Name, CustomMapperCB);
- UDMMap.try_emplace(D, NewFn);
+ llvm::Expected<llvm::Function *> NewFn = OMPBuilder.emitUserDefinedMapper(
+ PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB);
+ assert(NewFn && "Unexpected error in emitUserDefinedMapper");
+ UDMMap.try_emplace(D, *NewFn);
if (CGF)
FunctionUDMMap[CGF->CurFn].push_back(D);
}
@@ -10092,9 +10095,10 @@ void CGOpenMPRuntime::emitTargetDataCalls(
CGF.Builder.GetInsertPoint());
llvm::OpenMPIRBuilder::LocationDescription OmpLoc(CodeGenIP);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTargetData(
- OmpLoc, AllocaIP, CodeGenIP, DeviceID, IfCondVal, Info, GenMapInfoCB,
- /*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, CustomMapperCB, RTLoc);
+ OMPBuilder.createTargetData(OmpLoc, AllocaIP, CodeGenIP, DeviceID,
+ IfCondVal, Info, GenMapInfoCB, CustomMapperCB,
+ /*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB,
+ RTLoc);
assert(AfterIP && "unexpected error creating target data");
CGF.Builder.restoreIP(*AfterIP);
}
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 4ce47b1c05d9b0..4e80bff6db4553 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -22,6 +22,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Error.h"
#include "llvm/TargetParser/Triple.h"
#include <forward_list>
#include <map>
@@ -2355,6 +2356,7 @@ class OpenMPIRBuilder {
CurInfo.NonContigInfo.Strides.end());
}
};
+ using MapInfosOrErrorTy = Expected<MapInfosTy &>;
/// Callback function type for functions emitting the host fallback code that
/// is executed when the kernel launch fails. It takes an insertion point as
@@ -2431,9 +2433,9 @@ class OpenMPIRBuilder {
/// including base pointers, pointers, sizes, map types, user-defined mappers.
void emitOffloadingArrays(
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
- TargetDataInfo &Info, bool IsNonContiguous = false,
- function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
- function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
+ TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
+ bool IsNonContiguous = false,
+ function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
/// Allocates memory for and populates the arrays required for offloading
/// (offload_{baseptrs|ptrs|mappers|sizes|maptypes|mapnames}). Then, it
@@ -2444,9 +2446,9 @@ class OpenMPIRBuilder {
void emitOffloadingArraysAndArgs(
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
+ function_ref<Value *(unsigned int)> CustomMapperCB,
bool IsNonContiguous = false, bool ForEndCall = false,
- function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
- function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
+ function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
/// Creates offloading entry for the provided entry ID \a ID, address \a
/// Addr, size \a Size, and flags \a Flags.
@@ -2911,12 +2913,12 @@ class OpenMPIRBuilder {
/// \param FuncName Optional param to specify mapper function name.
/// \param CustomMapperCB Optional callback to generate code related to
/// custom mappers.
- Function *emitUserDefinedMapper(
- function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
- llvm::Value *BeginArg)>
+ Expected<Function *> emitUserDefinedMapper(
+ function_ref<MapInfosOrErrorTy(
+ InsertPointTy CodeGenIP, llvm::Value *PtrPHI, llvm::Value *BeginArg)>
PrivAndGenMapInfoCB,
llvm::Type *ElemTy, StringRef FuncName,
- function_ref<bool(unsigned int, Function **)> CustomMapperCB = nullptr);
+ function_ref<bool(unsigned int, Function **)> CustomMapperCB);
/// Generator for '#omp target data'
///
@@ -2930,21 +2932,21 @@ class OpenMPIRBuilder {
/// \param IfCond Value which corresponds to the if clause condition.
/// \param Info Stores all information realted to the Target Data directive.
/// \param GenMapInfoCB Callback that populates the MapInfos and returns.
+ /// \param CustomMapperCB Callback to generate code related to
+ /// custom mappers.
/// \param BodyGenCB Optional Callback to generate the region code.
/// \param DeviceAddrCB Optional callback to generate code related to
/// use_device_ptr and use_device_addr.
- /// \param CustomMapperCB Optional callback to generate code related to
- /// custom mappers.
InsertPointOrErrorTy createTargetData(
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
+ function_ref<Value *(unsigned int)> CustomMapperCB,
omp::RuntimeFunction *MapperFunc = nullptr,
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
BodyGenTy BodyGenType)>
BodyGenCB = nullptr,
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
- function_ref<Value *(unsigned int)> CustomMapperCB = nullptr,
Value *SrcLocInfo = nullptr);
using TargetBodyGenCallbackTy = function_ref<InsertPointOrErrorTy(
@@ -2960,6 +2962,7 @@ class OpenMPIRBuilder {
/// \param IsOffloadEntry whether it is an offload entry.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
+ /// \param Info Stores all information related to the Target directive.
/// \param EntryInfo The entry information about the function.
/// \param NumTeams Number of teams specified in the num_teams clause.
/// \param NumThreads Number of teams specified in the thread_limit clause.
@@ -2968,18 +2971,22 @@ class OpenMPIRBuilder {
/// \param BodyGenCB Callback that will generate the region code.
/// \param ArgAccessorFuncCB Callback that will generate accessors
/// instructions for passed in target arguments where neccessary
+ /// \param CustomMapperCB Callback to generate code related to
+ /// custom mappers.
/// \param Dependencies A vector of DependData objects that carry
// dependency information as passed in the depend clause
// \param HasNowait Whether the target construct has a `nowait` clause or not.
- InsertPointOrErrorTy createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry,
- OpenMPIRBuilder::InsertPointTy AllocaIP,
- OpenMPIRBuilder::InsertPointTy CodeGenIP,
- TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
- ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
- GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
- TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
- SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
+ InsertPointOrErrorTy
+ createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
+ TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+ ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+ GenMapInfoCallbackTy GenMapInfoCB,
+ TargetBodyGenCallbackTy BodyGenCB,
+ TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+ function_ref<Value *(unsigned int)> CustomMapperCB,
+ SmallVector<DependData> Dependencies, bool HasNowait);
/// Returns __kmpc_for_static_init_* runtime function for the specified
/// size \a IVSize and sign \a IVSigned. Will create a distribute call
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 0d8dbbe3a8a718..be53dbbf8addf3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -47,6 +47,7 @@
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
@@ -6480,12 +6481,12 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
+ function_ref<Value *(unsigned int)> CustomMapperCB,
omp::RuntimeFunction *MapperFunc,
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
BodyGenTy BodyGenType)>
BodyGenCB,
- function_ref<void(unsigned int, Value *)> DeviceAddrCB,
- function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
+ function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -6511,8 +6512,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
InsertPointTy CodeGenIP) -> Error {
MapInfo = &GenMapInfoCB(Builder.saveIP());
emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
- /*IsNonContiguous=*/true, DeviceAddrCB,
- CustomMapperCB);
+ CustomMapperCB,
+ /*IsNonContiguous=*/true, DeviceAddrCB);
TargetDataRTArgs RTArgs;
emitOffloadingArraysArgument(Builder, RTArgs, Info);
@@ -7304,22 +7305,24 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
- TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, bool IsNonContiguous,
- bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB,
- function_ref<Value *(unsigned int)> CustomMapperCB) {
- emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, IsNonContiguous,
- DeviceAddrCB, CustomMapperCB);
+ TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
+ function_ref<Value *(unsigned int)> CustomMapperCB, bool IsNonContiguous,
+ bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
+ emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, CustomMapperCB,
+ IsNonContiguous, DeviceAddrCB);
emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
}
static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::TargetDataInfo &Info, Function *OutlinedFn,
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
- SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
- bool HasNoWait = false) {
+ function_ref<Value *(unsigned int)> CustomMapperCB,
+ SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies,
+ bool HasNoWait) {
// Generate a function call to the host fallback implementation of the target
// region. This is called by the host when no offload entry was generated for
// the target region and when the offloading call fails at runtime.
@@ -7384,14 +7387,10 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
return;
}
- OpenMPIRBuilder::TargetDataInfo Info(
- /*RequiresDevicePointerInfo=*/false,
- /*SeparateBeginEndCalls=*/true);
-
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
- RTArgs, MapInfo,
+ RTArgs, MapInfo, CustomMapperCB,
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);
@@ -7439,11 +7438,13 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
- InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
- ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
- SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
+ InsertPointTy CodeGenIP, TargetDataInfo &Info,
+ TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+ ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+ GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+ function_ref<Value *(unsigned int)> CustomMapperCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
if (!updateToLocation(Loc))
@@ -7458,15 +7459,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// and ArgAccessorFuncCB
if (Error Err = emitTargetOutlinedFunction(
*this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
- Args, CBFunc, ArgAccessorFuncCB))
+ Inputs, CBFunc, ArgAccessorFuncCB))
return Err;
// If we are not on the target device, then we need to generate code
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
- NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+ emitTargetCall(*this, Builder, AllocaIP, Info, OutlinedFn, OutlinedFnID,
+ NumTeams, NumThreads, Inputs, GenMapInfoCB, CustomMapperCB,
+ Dependencies, HasNowait);
return Builder.saveIP();
}
@@ -7791,9 +7793,9 @@ void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
OffloadingArgs);
}
-Function *OpenMPIRBuilder::emitUserDefinedMapper(
- function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
- llvm::Value *BeginArg)>
+Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
+ function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
+ llvm::Value *BeginArg)>
GenMapInfoCB,
Type *ElemTy, StringRef FuncName,
function_ref<bool(unsigned int, Function **)> CustomMapperCB) {
@@ -7867,7 +7869,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
PtrPHI->addIncoming(PtrBegin, HeadBB);
// Get map clause information. Fill up the arrays with all mapped variables.
- MapInfosTy &Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
+ MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
+ if (!Info)
+ return Info.takeError();
// Call the runtime API __tgt_mapper_num_components to get the number of
// pre-existing components.
@@ -7879,20 +7883,20 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
Builder.CreateShl(PreviousSize, Builder.getInt64(getFlagMemberOffset()));
// Fill up the runtime mapper handle for all components.
- for (unsigned I = 0; I < Info.BasePointers.size(); ++I) {
+ for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
Value *CurBaseArg =
- Builder.CreateBitCast(Info.BasePointers[I], Builder.getPtrTy());
+ Builder.CreateBitCast(Info->BasePointers[I], Builder.getPtrTy());
Value *CurBeginArg =
- Builder.CreateBitCast(Info.Pointers[I], Builder.getPtrTy());
- Value *CurSizeArg = Info.Sizes[I];
- Value *CurNameArg = Info.Names.size()
- ? Info.Names[I]
+ Builder.CreateBitCast(Info->Pointers[I], Builder.getPtrTy());
+ Value *CurSizeArg = Info->Sizes[I];
+ Value *CurNameArg = Info->Names.size()
+ ? Info->Names[I]
: Constant::getNullValue(Builder.getPtrTy());
// Extract the MEMBER_OF field from the map type.
Value *OriMapType = Builder.getInt64(
static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
- Info.Types[I]));
+ Info->Types[I]));
Value *MemberMapType =
Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize);
@@ -8013,9 +8017,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
void OpenMPIRBuilder::emitOffloadingArrays(
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
- TargetDataInfo &Info, bool IsNonContiguous,
- function_ref<void(unsigned int, Value *)> DeviceAddrCB,
- function_ref<Value *(unsigned int)> CustomMapperCB) {
+ TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
+ bool IsNonContiguous,
+ function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
// Reset the array information.
Info.clearArrayInfo();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index d7ac1082491180..a33e1533dede43 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5876,6 +5876,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) {
return CombinedInfo;
};
+ auto CustomMapperCB = [&](unsigned int I) { return nullptr; };
llvm::OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);
@@ -5885,7 +5886,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) {
llvm::omp::RuntimeFunction RTLFunc = OMPRTL___tgt_target_data_begin_mapper;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTargetData(
Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID),
- /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc);
+ /* IfCond= */ nullptr, Info, GenMapInfoCB, CustomMapperCB, &RTLFunc);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -5937,6 +5938,7 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) {
return CombinedInfo;
};
+ auto CustomMapperCB = [&](unsigned int I) { return nullptr; };
llvm::OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);
@@ -5946,7 +5948,7 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) {
llvm::omp::RuntimeFunction RTLFunc = OMPRTL___tgt_target_data_end_mapper;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTargetData(
Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID),
- /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc);
+ /* IfCond= */ nullptr, Info, GenMapInfoCB, CustomMapperCB, &RTLFunc);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -6020,6 +6022,7 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) {
return CombinedInfo;
};
+ auto CustomMapperCB = [&](unsigned int I) { return nullptr; };
llvm::OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/true,
/*SeparateBeginEndCalls=*/true);
@@ -6055,9 +6058,10 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) {
};
OpenMPIRBuilder::InsertPointOrErrorTy TargetDataIP1 =
- OMPBuilder.createTargetData(
- Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID),
- /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyCB);
+ OMPBuilder.createTargetData(Loc, AllocaIP, Builder.saveIP(),
+ Builder.getInt64(DeviceID),
+ /* IfCond= */ nullptr, Info, GenMapInfoCB,
+ CustomMapperCB, nullptr, BodyCB);
assert(TargetDataIP1 && "unexpected error");
Builder.restoreIP(*TargetDataIP1);
@@ -6083,9 +6087,10 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) {
return Builder.saveIP();
};
OpenMPIRBuilder::InsertPointOrErrorTy TargetDataIP2 =
- OMPBuilder.createTargetData(
- Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID),
- /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyTargetCB);
+ OMPBuilder.createTargetData(Loc, AllocaIP, Builder.saveIP(),
+ Builder.getInt64(DeviceID),
+ /* IfCond= */ nullptr, Info, GenMapInfoCB,
+ CustomMapperCB, nullptr, BodyTargetCB);
assert(TargetDataIP2 && "unexpected error");
Builder.restoreIP(*TargetDataIP2);
EXPECT_TRUE(CheckDevicePassBodyGen);
@@ -6180,11 +6185,16 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
return CombinedInfos;
};
+ auto CustomMapperCB = [&](unsigned int I) { return nullptr; };
+ llvm::OpenMPIRBuilder::TargetDataInfo Info(/*RequiresDevicePointerInfo=*/true,
+ /*SeparateBeginEndCalls=*/true);
+
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(),
- EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(), Info,
+ EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB,
+ CustomMapperCB, {}, false);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
OMPBuilder.finalize();
@@ -6278,6 +6288,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
return CombinedInfos;
};
+ auto CustomMapperCB = [&](unsigned int I) { return nullptr; };
auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP)
-> OpenMPIRBuilder::InsertPointTy {
@@ -6291,12 +6302,14 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
F->getEntryBlock().getFirstInsertionPt());
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
+ llvm::OpenMPIRBuilder::TargetDataInfo Info(/*RequiresDevicePointerInfo=*/true,
+ /*SeparateBeginEndCalls=*/true);
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
- BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, Info, EntryInfo,
+ /*NumTeams=*/-1,
+ /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, BodyGenCB,
+ SimpleArgAccessorCB, CustomMapperCB, {}, false);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -6432,6 +6445,7 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
return CombinedInfos;
};
+ auto CustomMapperCB = [&](unsigned int I) { return nullptr; };
llvm::Value *RaiseAlloca = nullptr;
auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
@@ -6448,12 +6462,14 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
F->getEntryBlock().getFirstInsertionPt());
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
+ llvm::OpenMPIRBuilder::TargetDataInfo Info(/*RequiresDevicePointerInfo=*/true,
+ /*SeparateBeginEndCalls=*/true);
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
- BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, Info, EntryInfo,
+ /*NumTeams=*/-1,
+ /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, BodyGenCB,
+ SimpleArgAccessorCB, CustomMapperCB, {}, false);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d591c98a5497f8..024cc15518b5ac 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -28,8 +28,10 @@
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -2709,13 +2711,23 @@ getRefPtrIfDeclareTarget(mlir::Value value,
}
namespace {
+// llvm::OpenMPIRBuilder::MapInfosTy extended with per-entry custom mapper info.
+struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
+ SmallVector<Operation *, 4> Mappers; // DeclareMapperOp per entry, or nullptr.
+
+ /// Append the arrays of \a curInfo (including its Mappers) to this object.
+ void append(MapInfosTy &curInfo) {
+ Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end());
+ llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo);
+ }
+};
// A small helper structure to contain data gathered
// for map lowering and coalese it into one area and
// avoiding extra computations such as searches in the
// llvm module for lowered mapped variables or checking
// if something is declare target (and retrieving the
// value) more than neccessary.
-struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
+struct MapInfoData : MapInfosTy {
llvm::SmallVector<bool, 4> IsDeclareTarget;
llvm::SmallVector<bool, 4> IsAMember;
// Identify if mapping was added by mapClause or use_device clauses.
@@ -2734,7 +2746,7 @@ struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
OriginalValue.append(CurInfo.OriginalValue.begin(),
CurInfo.OriginalValue.end());
BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end());
- llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
+ MapInfosTy::append(CurInfo);
}
};
} // namespace
@@ -2855,6 +2867,12 @@ static void collectMapDataFromMapOperands(
mapData.Names.push_back(LLVM::createMappingInformation(
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+ if (mapOp.getMapperId())
+ mapData.Mappers.push_back(
+ SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
+ mapOp, mapOp.getMapperIdAttr()));
+ else
+ mapData.Mappers.push_back(nullptr);
mapData.IsAMapping.push_back(true);
mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
}
@@ -2899,6 +2917,7 @@ static void collectMapDataFromMapOperands(
mapData.Names.push_back(LLVM::createMappingInformation(
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
mapData.DevicePointers.push_back(devInfoTy);
+ mapData.Mappers.push_back(nullptr);
mapData.IsAMapping.push_back(false);
mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
}
@@ -3064,9 +3083,8 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
// inside of CGOpenMPRuntime.cpp
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
- llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
- llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
- uint64_t mapDataIndex, bool isTargetParams) {
+ llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
+ MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) {
// Map the first segment of our structure
combinedInfo.Types.emplace_back(
isTargetParams
@@ -3074,6 +3092,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
combinedInfo.DevicePointers.emplace_back(
mapData.DevicePointers[mapDataIndex]);
+ combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
@@ -3137,6 +3156,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+ combinedInfo.Mappers.emplace_back(nullptr);
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
@@ -3170,9 +3190,9 @@ static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
// This function is intended to add explicit mappings of members
static void processMapMembersWithParent(
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
- llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
- llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
- uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
+ llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
+ MapInfoData &mapData, uint64_t mapDataIndex,
+ llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
auto parentClause =
llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
@@ -3200,6 +3220,7 @@ static void processMapMembersWithParent(
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+ combinedInfo.Mappers.emplace_back(nullptr);
combinedInfo.Names.emplace_back(
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(
@@ -3222,6 +3243,7 @@ static void processMapMembersWithParent(
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
mapData.DevicePointers[memberDataIdx]);
+ combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
combinedInfo.Names.emplace_back(
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
uint64_t basePointerIndex =
@@ -3233,10 +3255,9 @@ static void processMapMembersWithParent(
}
}
-static void
-processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
- llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
- bool isTargetParams, int mapDataParentIdx = -1) {
+static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
+ MapInfosTy &combinedInfo, bool isTargetParams,
+ int mapDataParentIdx = -1) {
// Declare Target Mappings are excluded from being marked as
// OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
// marked with OMP_MAP_PTR_AND_OBJ instead.
@@ -3266,16 +3287,18 @@ processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
+ combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]);
combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
}
-static void processMapWithMembersOf(
- LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
- llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
- llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
- uint64_t mapDataIndex, bool isTargetParams) {
+static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation,
+ llvm::IRBuilderBase &builder,
+ llvm::OpenMPIRBuilder &ompBuilder,
+ DataLayout &dl, MapInfosTy &combinedInfo,
+ MapInfoData &mapData, uint64_t mapDataIndex,
+ bool isTargetParams) {
auto parentClause =
llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
@@ -3380,8 +3403,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
// Generate all map related information and fill the combinedInfo.
static void genMapInfos(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
- DataLayout &dl,
- llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
+ DataLayout &dl, MapInfosTy &combinedInfo,
MapInfoData &mapData, bool isTargetParams = false) {
// We wish to modify some of the methods in which arguments are
// passed based on their capture type by the target region, this can
@@ -3421,6 +3443,88 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
}
}
+static llvm::Expected<llvm::Function *>
+emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation);
+
+static llvm::Expected<llvm::Function *>
+getOrCreateUserDefinedMapperFunc(Operation *declMapperOp,
+ llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::DenseMap<const Operation *, llvm::Function *> userDefMapperMap;
+ auto iter = userDefMapperMap.find(declMapperOp);
+ if (iter != userDefMapperMap.end())
+ return iter->second;
+ llvm::Expected<llvm::Function *> mapperFunc =
+ emitUserDefinedMapper(declMapperOp, builder, moduleTranslation);
+ if (!mapperFunc)
+ return mapperFunc.takeError();
+ userDefMapperMap.try_emplace(declMapperOp, *mapperFunc);
+ return userDefMapperMap.lookup(declMapperOp);
+}
+
+static llvm::Expected<llvm::Function *>
+emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ auto declMapperOp = cast<omp::DeclareMapperOp>(op);
+ auto declMapperInfoOp =
+ *declMapperOp.getOps<omp::DeclareMapperInfoOp>().begin();
+ DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ llvm::Type *varType =
+ moduleTranslation.convertType(declMapperOp.getVarType());
+ std::string mapperName = ompBuilder->createPlatformSpecificName(
+ {"omp_mapper", declMapperOp.getSymName()});
+ SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();
+
+ using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+
+ // Fill up the arrays with all the mapped variables.
+ MapInfosTy combinedInfo;
+ auto genMapInfoCB =
+ [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
+ llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
+ builder.restoreIP(codeGenIP);
+ moduleTranslation.mapValue(declMapperOp.getRegion().getArgument(0), ptrPHI);
+ moduleTranslation.mapBlock(&declMapperOp.getRegion().front(),
+ builder.GetInsertBlock());
+ if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(),
+ /*ignoreArguments=*/true,
+ builder)))
+ return llvm::make_error<PreviouslyReportedError>();
+ MapInfoData mapData;
+ collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
+ builder);
+ genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);
+
+ // Drop the mapping that is no longer necessary so that the same region can
+ // be processed multiple times.
+ moduleTranslation.forgetMapping(declMapperOp.getRegion());
+ return combinedInfo;
+ };
+
+ auto customMapperCB = [&](unsigned i, llvm::Function **mapperFunc) {
+ if (combinedInfo.Mappers[i]) {
+ // Call the corresponding mapper function. This callback cannot return an
+ // llvm::Error, so abort with it rather than dereference a failed Expected
+ // (an assert alone is compiled out in release builds).
+ llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
+ combinedInfo.Mappers[i], builder, moduleTranslation);
+ if (!newFn)
+ llvm::report_fatal_error(newFn.takeError());
+ *mapperFunc = *newFn;
+ return true;
+ }
+ return false;
+ };
+
+ llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
+ genMapInfoCB, varType, mapperName, customMapperCB);
+ if (!newFn)
+ return newFn.takeError();
+ return *newFn;
+}
+
static LogicalResult
convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
@@ -3532,9 +3633,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
builder, useDevicePtrVars, useDeviceAddrVars);
// Fill up the arrays with all the mapped variables.
- llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
- auto genMapInfoCB =
- [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & {
+ MapInfosTy combinedInfo;
+ auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
builder.restoreIP(codeGenIP);
genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
return combinedInfo;
@@ -3577,6 +3677,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
-> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
+ builder.restoreIP(codeGenIP);
assert(isa<omp::TargetDataOp>(op) &&
"BodyGen requested for non TargetDataOp");
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
@@ -3585,8 +3686,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
case BodyGenTy::Priv:
// Check if any device ptr/addr info is available
if (!info.DevicePtrInfoMap.empty()) {
- builder.restoreIP(codeGenIP);
-
mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
blockArgIface.getUseDeviceAddrBlockArgs(),
useDeviceAddrVars, mapData,
@@ -3613,7 +3712,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
case BodyGenTy::NoPriv:
// If device info is available then region has already been generated
if (info.DevicePtrInfoMap.empty()) {
- builder.restoreIP(codeGenIP);
// For device pass, if use_device_ptr(addr) mappings were present,
// we need to link them here before codegen.
if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
@@ -3634,6 +3732,22 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
return builder.saveIP();
};
+ // Resolve the user-defined mapper for map entry `i`, if any. The callback
+ // cannot return an llvm::Error, so abort with it rather than dereference a
+ // failed Expected (an assert alone is compiled out in release builds).
+ auto customMapperCB = [&](unsigned int i) {
+ llvm::Function *mapperFunc = nullptr;
+ if (combinedInfo.Mappers[i]) {
+ info.HasMapper = true;
+ llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
+ combinedInfo.Mappers[i], builder, moduleTranslation);
+ if (!newFn)
+ llvm::report_fatal_error(newFn.takeError());
+ mapperFunc = *newFn;
+ }
+ return mapperFunc;
+ };
+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
@@ -3641,10 +3751,11 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (isa<omp::TargetDataOp>(op))
return ompBuilder->createTargetData(
ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID),
- ifCond, info, genMapInfoCB, nullptr, bodyGenCB);
- return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(),
- builder.getInt64(deviceID), ifCond,
- info, genMapInfoCB, &RTLFn);
+ ifCond, info, genMapInfoCB, customMapperCB, nullptr, bodyGenCB,
+ /*DeviceAddrCB=*/nullptr);
+ return ompBuilder->createTargetData(
+ ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond,
+ info, genMapInfoCB, customMapperCB, &RTLFn);
}();
if (failed(handleError(afterIP, *op)))
@@ -4032,9 +4143,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
builder);
- llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
- auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
- -> llvm::OpenMPIRBuilder::MapInfosTy & {
+ MapInfosTy combinedInfos;
+ auto genMapInfoCB =
+ [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
builder.restoreIP(codeGenIP);
genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true);
return combinedInfos;
@@ -4079,11 +4190,31 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::OpenMPIRBuilder::TargetDataInfo info(
+ /*RequiresDevicePointerInfo=*/false,
+ /*SeparateBeginEndCalls=*/true);
+
+ // Resolve the user-defined mapper for map entry `i`, if any. The callback
+ // cannot return an llvm::Error, so abort with it rather than dereference a
+ // failed Expected (an assert alone is compiled out in release builds).
+ auto customMapperCB = [&](unsigned int i) {
+ llvm::Function *mapperFunc = nullptr;
+ if (combinedInfos.Mappers[i]) {
+ info.HasMapper = true;
+ llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
+ combinedInfos.Mappers[i], builder, moduleTranslation);
+ if (!newFn)
+ llvm::report_fatal_error(newFn.takeError());
+ mapperFunc = *newFn;
+ }
+ return mapperFunc;
+ };
+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
- ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
+ ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
- argAccessorCB, dds, targetOp.getNowait());
+ argAccessorCB, customMapperCB, dds, targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
@@ -4302,7 +4429,8 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
.Case([&](omp::TaskwaitOp op) {
return convertOmpTaskwaitOp(op, builder, moduleTranslation);
})
- .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareReductionOp,
+ .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
+ omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
omp::CriticalDeclareOp>([](auto op) {
// `yield` and `terminator` can be just omitted. The block structure
// was created in the region that handles their parent operation.
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index 7f21095763a397..fcbc57f67ae1b9 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -485,3 +485,120 @@ llvm.func @_QPopenmp_target_data_update() {
// CHECK: call void @__tgt_target_data_update_mapper(ptr @2, i64 -1, i32 1, ptr %[[BASEPTRS_VAL_2]], ptr %[[PTRS_VAL_2]], ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr null)
// CHECK: ret void
+
+// -----
+
+omp.declare_mapper @_QQFmy_testmy_mapper : !llvm.struct<"_QFmy_testTmy_type", (i32)> {
+^bb0(%arg0: !llvm.ptr):
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ %1 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)>
+ %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var%data"}
+ %3 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)>) map_clauses(tofrom) capture(ByRef) members(%2 : [0] : !llvm.ptr) -> !llvm.ptr {name = "var", partial_map = true}
+ omp.declare_mapper_info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr)
+}
+
+llvm.func @_QPopenmp_target_data_mapper() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x !llvm.struct<"_QFmy_testTmy_type", (i32)> {bindc_name = "a"} : (i64) -> !llvm.ptr
+ %2 = omp.map.info var_ptr(%1 : !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)>) mapper(@_QQFmy_testmy_mapper) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "a"}
+ omp.target_data map_entries(%2 : !llvm.ptr) {
+ %3 = llvm.mlir.constant(10 : i32) : i32
+ %4 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)>
+ llvm.store %3, %4 : i32, !llvm.ptr
+ omp.terminator
+ }
+ llvm.return
+}
+
+// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 4]
+// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 3]
+// CHECK-LABEL: define void @_QPopenmp_target_data_mapper
+// CHECK: %[[VAL_0:.*]] = alloca [1 x ptr], align 8
+// CHECK: %[[VAL_1:.*]] = alloca [1 x ptr], align 8
+// CHECK: %[[VAL_2:.*]] = alloca [1 x ptr], align 8
+// CHECK: %[[VAL_3:.*]] = alloca %[[VAL_4:.*]], i64 1, align 8
+// CHECK: br label %[[VAL_5:.*]]
+// CHECK: entry: ; preds = %[[VAL_6:.*]]
+// CHECK: %[[VAL_7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0
+// CHECK: store ptr %[[VAL_3]], ptr %[[VAL_7]], align 8
+// CHECK: %[[VAL_8:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0
+// CHECK: store ptr %[[VAL_3]], ptr %[[VAL_8]], align 8
+// CHECK: %[[VAL_9:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_2]], i64 0, i64 0
+// CHECK: store ptr @.omp_mapper._QQFmy_testmy_mapper, ptr %[[VAL_9]], align 8
+// CHECK: %[[VAL_10:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0
+// CHECK: %[[VAL_11:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr @4, i64 -1, i32 1, ptr %[[VAL_10]], ptr %[[VAL_11]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr %[[VAL_2]])
+// CHECK: %[[VAL_12:.*]] = getelementptr %[[VAL_4]], ptr %[[VAL_3]], i32 0, i32 0
+// CHECK: store i32 10, ptr %[[VAL_12]], align 4
+// CHECK: %[[VAL_13:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0
+// CHECK: %[[VAL_14:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0
+// CHECK: call void @__tgt_target_data_end_mapper(ptr @4, i64 -1, i32 1, ptr %[[VAL_13]], ptr %[[VAL_14]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr %[[VAL_2]])
+// CHECK: ret void
+
+// CHECK-LABEL: define internal void @.omp_mapper._QQFmy_testmy_mapper
+// CHECK: entry:
+// CHECK: %[[VAL_15:.*]] = udiv exact i64 %[[VAL_16:.*]], 4
+// CHECK: %[[VAL_17:.*]] = getelementptr %[[VAL_18:.*]], ptr %[[VAL_19:.*]], i64 %[[VAL_15]]
+// CHECK: %[[VAL_20:.*]] = icmp sgt i64 %[[VAL_15]], 1
+// CHECK: %[[VAL_21:.*]] = and i64 %[[VAL_22:.*]], 8
+// CHECK: %[[VAL_23:.*]] = icmp ne ptr %[[VAL_24:.*]], %[[VAL_19]]
+// CHECK: %[[VAL_25:.*]] = and i64 %[[VAL_22]], 16
+// CHECK: %[[VAL_26:.*]] = icmp ne i64 %[[VAL_25]], 0
+// CHECK: %[[VAL_27:.*]] = and i1 %[[VAL_23]], %[[VAL_26]]
+// CHECK: %[[VAL_28:.*]] = or i1 %[[VAL_20]], %[[VAL_27]]
+// CHECK: %[[VAL_29:.*]] = icmp eq i64 %[[VAL_21]], 0
+// CHECK: %[[VAL_30:.*]] = and i1 %[[VAL_28]], %[[VAL_29]]
+// CHECK: br i1 %[[VAL_30]], label %[[VAL_31:.*]], label %[[VAL_32:.*]]
+// CHECK: .omp.array..init: ; preds = %[[VAL_33:.*]]
+// CHECK: %[[VAL_34:.*]] = mul nuw i64 %[[VAL_15]], 4
+// CHECK: %[[VAL_35:.*]] = and i64 %[[VAL_22]], -4
+// CHECK: %[[VAL_36:.*]] = or i64 %[[VAL_35]], 512
+// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37:.*]], ptr %[[VAL_24]], ptr %[[VAL_19]], i64 %[[VAL_34]], i64 %[[VAL_36]], ptr %[[VAL_38:.*]])
+// CHECK: br label %[[VAL_32]]
+// CHECK: omp.arraymap.head: ; preds = %[[VAL_31]], %[[VAL_33]]
+// CHECK: %[[VAL_39:.*]] = icmp eq ptr %[[VAL_19]], %[[VAL_17]]
+// CHECK: br i1 %[[VAL_39]], label %[[VAL_40:.*]], label %[[VAL_41:.*]]
+// CHECK: omp.arraymap.body: ; preds = %[[VAL_42:.*]], %[[VAL_32]]
+// CHECK: %[[VAL_43:.*]] = phi ptr [ %[[VAL_19]], %[[VAL_32]] ], [ %[[VAL_44:.*]], %[[VAL_42]] ]
+// CHECK: %[[VAL_45:.*]] = getelementptr %[[VAL_18]], ptr %[[VAL_43]], i32 0, i32 0
+// CHECK: %[[VAL_46:.*]] = call i64 @__tgt_mapper_num_components(ptr %[[VAL_37]])
+// CHECK: %[[VAL_47:.*]] = shl i64 %[[VAL_46]], 48
+// CHECK: %[[VAL_48:.*]] = add nuw i64 3, %[[VAL_47]]
+// CHECK: %[[VAL_49:.*]] = and i64 %[[VAL_22]], 3
+// CHECK: %[[VAL_50:.*]] = icmp eq i64 %[[VAL_49]], 0
+// CHECK: br i1 %[[VAL_50]], label %[[VAL_51:.*]], label %[[VAL_52:.*]]
+// CHECK: omp.type.alloc: ; preds = %[[VAL_41]]
+// CHECK: %[[VAL_53:.*]] = and i64 %[[VAL_48]], -4
+// CHECK: br label %[[VAL_42]]
+// CHECK: omp.type.alloc.else: ; preds = %[[VAL_41]]
+// CHECK: %[[VAL_54:.*]] = icmp eq i64 %[[VAL_49]], 1
+// CHECK: br i1 %[[VAL_54]], label %[[VAL_55:.*]], label %[[VAL_56:.*]]
+// CHECK: omp.type.to: ; preds = %[[VAL_52]]
+// CHECK: %[[VAL_57:.*]] = and i64 %[[VAL_48]], -3
+// CHECK: br label %[[VAL_42]]
+// CHECK: omp.type.to.else: ; preds = %[[VAL_52]]
+// CHECK: %[[VAL_58:.*]] = icmp eq i64 %[[VAL_49]], 2
+// CHECK: br i1 %[[VAL_58]], label %[[VAL_59:.*]], label %[[VAL_42]]
+// CHECK: omp.type.from: ; preds = %[[VAL_56]]
+// CHECK: %[[VAL_60:.*]] = and i64 %[[VAL_48]], -2
+// CHECK: br label %[[VAL_42]]
+// CHECK: omp.type.end: ; preds = %[[VAL_59]], %[[VAL_56]], %[[VAL_55]], %[[VAL_51]]
+// CHECK: %[[VAL_61:.*]] = phi i64 [ %[[VAL_53]], %[[VAL_51]] ], [ %[[VAL_57]], %[[VAL_55]] ], [ %[[VAL_60]], %[[VAL_59]] ], [ %[[VAL_48]], %[[VAL_56]] ]
+// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37]], ptr %[[VAL_43]], ptr %[[VAL_45]], i64 4, i64 %[[VAL_61]], ptr @2)
+// CHECK: %[[VAL_44]] = getelementptr %[[VAL_18]], ptr %[[VAL_43]], i32 1
+// CHECK: %[[VAL_62:.*]] = icmp eq ptr %[[VAL_44]], %[[VAL_17]]
+// CHECK: br i1 %[[VAL_62]], label %[[VAL_63:.*]], label %[[VAL_41]]
+// CHECK: omp.arraymap.exit: ; preds = %[[VAL_42]]
+// CHECK: %[[VAL_64:.*]] = icmp sgt i64 %[[VAL_15]], 1
+// CHECK: %[[VAL_65:.*]] = and i64 %[[VAL_22]], 8
+// CHECK: %[[VAL_66:.*]] = icmp ne i64 %[[VAL_65]], 0
+// CHECK: %[[VAL_67:.*]] = and i1 %[[VAL_64]], %[[VAL_66]]
+// CHECK: br i1 %[[VAL_67]], label %[[VAL_68:.*]], label %[[VAL_40]]
+// CHECK: .omp.array..del: ; preds = %[[VAL_63]]
+// CHECK: %[[VAL_69:.*]] = mul nuw i64 %[[VAL_15]], 4
+// CHECK: %[[VAL_70:.*]] = and i64 %[[VAL_22]], -4
+// CHECK: %[[VAL_71:.*]] = or i64 %[[VAL_70]], 512
+// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37]], ptr %[[VAL_24]], ptr %[[VAL_19]], i64 %[[VAL_69]], i64 %[[VAL_71]], ptr %[[VAL_38]])
+// CHECK: br label %[[VAL_40]]
+// CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]]
+// CHECK: ret void
diff --git a/offload/test/offloading/fortran/target-custom-mapper.f90 b/offload/test/offloading/fortran/target-custom-mapper.f90
new file mode 100644
index 00000000000000..5699a0613d9abb
--- /dev/null
+++ b/offload/test/offloading/fortran/target-custom-mapper.f90
@@ -0,0 +1,46 @@
+! Offloading test checking user-defined (custom) mappers on a target map clause.
+! REQUIRES: flang, amdgpu
+
+! RUN: %libomptarget-compile-fortran-run-and-check-generic
+
+program test_openmp_mapper
+ implicit none
+ integer, parameter :: n = 1024
+ type :: mytype
+ integer :: data(n)
+ end type mytype
+
+ ! Declare a custom mapper for the derived type `mytype` with the name `my_mapper`
+ !$omp declare mapper(my_mapper : mytype :: t) map(to: t%data)
+
+ type(mytype) :: obj
+ integer :: i, sum_host, sum_device
+
+ ! Initialize the host data
+ do i = 1, n
+ obj%data(i) = 1
+ end do
+
+ ! Compute the sum on the host for verification
+ sum_host = sum(obj%data)
+
+ ! Offload computation to the device using the named mapper `my_mapper`
+ sum_device = 0
+ !$omp target map(tofrom: sum_device) map(mapper(my_mapper) : obj)
+ do i = 1, n
+ sum_device = sum_device + obj%data(i)
+ end do
+ !$omp end target
+
+ ! Check results
+ print *, "Sum on host: ", sum_host
+ print *, "Sum on device: ", sum_device
+
+ if (sum_device == sum_host) then
+ print *, "Test passed!"
+ else
+ print *, "Test failed!"
+ end if
+end program test_openmp_mapper
+
+! CHECK: Test passed!
>From c4569ed1cdc7c16339fa03a32443c4eb74a47949 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Tue, 28 Jan 2025 15:01:27 +0000
Subject: [PATCH 4/5] Fix IRBuilderTest failure.
---
clang/lib/CodeGen/CGOpenMPRuntime.cpp | 2 --
llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h | 1 -
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 1 -
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 15 +++++++++------
.../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 2 --
5 files changed, 9 insertions(+), 12 deletions(-)
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 0a13581dcb1700..7492a1cab8803a 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -32,12 +32,10 @@
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/AtomicOrdering.h"
-#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cstdint>
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 4e80bff6db4553..4762a836f9dade 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -22,7 +22,6 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Allocator.h"
-#include "llvm/Support/Error.h"
#include "llvm/TargetParser/Triple.h"
#include <forward_list>
#include <map>
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index be53dbbf8addf3..5fa946626a8b06 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -47,7 +47,6 @@
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index a33e1533dede43..785876d3ca1d1f 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6186,8 +6186,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
};
auto CustomMapperCB = [&](unsigned int I) { return nullptr; };
- llvm::OpenMPIRBuilder::TargetDataInfo Info(/*RequiresDevicePointerInfo=*/true,
- /*SeparateBeginEndCalls=*/true);
+ llvm::OpenMPIRBuilder::TargetDataInfo Info(
+ /*RequiresDevicePointerInfo=*/false,
+ /*SeparateBeginEndCalls=*/true);
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
@@ -6302,8 +6303,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
F->getEntryBlock().getFirstInsertionPt());
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
- llvm::OpenMPIRBuilder::TargetDataInfo Info(/*RequiresDevicePointerInfo=*/true,
- /*SeparateBeginEndCalls=*/true);
+ llvm::OpenMPIRBuilder::TargetDataInfo Info(
+ /*RequiresDevicePointerInfo=*/false,
+ /*SeparateBeginEndCalls=*/true);
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, Info, EntryInfo,
@@ -6462,8 +6464,9 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
F->getEntryBlock().getFirstInsertionPt());
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
- llvm::OpenMPIRBuilder::TargetDataInfo Info(/*RequiresDevicePointerInfo=*/true,
- /*SeparateBeginEndCalls=*/true);
+ llvm::OpenMPIRBuilder::TargetDataInfo Info(
+ /*RequiresDevicePointerInfo=*/false,
+ /*SeparateBeginEndCalls=*/true);
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, Info, EntryInfo,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 024cc15518b5ac..af831dbb243f3e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -28,10 +28,8 @@
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/DebugInfoMetadata.h"
-#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
-#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
>From 35e6331f620665260e222225ed048463736b09e6 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 31 Jan 2025 10:34:04 +0000
Subject: [PATCH 5/5] Address reviewer comments.
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 17 ++--
.../fortran/target-custom-mapper.f90 | 77 ++++++++++---------
2 files changed, 50 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index af831dbb243f3e..d78b7ef9462917 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3449,16 +3449,21 @@ static llvm::Expected<llvm::Function *>
getOrCreateUserDefinedMapperFunc(Operation *declMapperOp,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
- llvm::DenseMap<const Operation *, llvm::Function *> userDefMapperMap;
- auto iter = userDefMapperMap.find(declMapperOp);
- if (iter != userDefMapperMap.end())
- return iter->second;
+ // Reuse a mapper already emitted into the current llvm::Module. Caching by
+ // Operation pointer in a function-local static would outlive this module's
+ // translation and could hand back a stale function for a recycled pointer.
+ auto mapperOp = cast<omp::DeclareMapperOp>(declMapperOp);
+ std::string mapperName =
+ moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
+ {"omp_mapper", mapperOp.getSymName()});
+ if (llvm::Function *existing =
+ moduleTranslation.getLLVMModule()->getFunction(mapperName))
+ return existing;
llvm::Expected<llvm::Function *> mapperFunc =
emitUserDefinedMapper(declMapperOp, builder, moduleTranslation);
if (!mapperFunc)
return mapperFunc.takeError();
- userDefMapperMap.try_emplace(declMapperOp, *mapperFunc);
- return userDefMapperMap.lookup(declMapperOp);
+ return mapperFunc;
}
static llvm::Expected<llvm::Function *>
diff --git a/offload/test/offloading/fortran/target-custom-mapper.f90 b/offload/test/offloading/fortran/target-custom-mapper.f90
index 5699a0613d9abb..f81ec538f565d8 100644
--- a/offload/test/offloading/fortran/target-custom-mapper.f90
+++ b/offload/test/offloading/fortran/target-custom-mapper.f90
@@ -4,43 +4,52 @@
! RUN: %libomptarget-compile-fortran-run-and-check-generic
program test_openmp_mapper
- implicit none
- integer, parameter :: n = 1024
- type :: mytype
- integer :: data(n)
- end type mytype
-
- ! Declare a custom mapper for the derived type `mytype` with the name `my_mapper`
- !$omp declare mapper(my_mapper : mytype :: t) map(to: t%data)
-
- type(mytype) :: obj
- integer :: i, sum_host, sum_device
-
- ! Initialize the host data
- do i = 1, n
- obj%data(i) = 1
- end do
-
- ! Compute the sum on the host for verification
- sum_host = sum(obj%data)
-
- ! Offload computation to the device using the named mapper `my_mapper`
- sum_device = 0
- !$omp target map(tofrom: sum_device) map(mapper(my_mapper) : obj)
- do i = 1, n
- sum_device = sum_device + obj%data(i)
- end do
- !$omp end target
-
- ! Check results
- print *, "Sum on host: ", sum_host
- print *, "Sum on device: ", sum_device
-
- if (sum_device == sum_host) then
- print *, "Test passed!"
- else
- print *, "Test failed!"
- end if
-end program test_openmp_mapper
+ ! Checks that a named user-defined mapper (`my_mapper2`) which maps a
+ ! component through another named mapper (`my_mapper1`) transfers the data.
+ implicit none
+ integer, parameter :: n = 1024
+ type :: mytype
+ integer :: data(n)
+ end type mytype
+
+ ! Wrapper type so that `my_mapper1` (declared for `mytype`) is applied to a
+ ! list item of type `mytype`, as OpenMP requires for the mapper modifier.
+ type :: mytype2
+ type(mytype) :: my_data
+ end type mytype2
+
+ ! Declare custom mappers for the derived types `mytype` and `mytype2`
+ !$omp declare mapper(my_mapper1 : mytype :: t) map(to: t%data)
+ !$omp declare mapper(my_mapper2 : mytype2 :: t) map(mapper(my_mapper1): t%my_data)
+
+ type(mytype2) :: obj
+ integer :: i, sum_host, sum_device
+
+ ! Initialize the host data
+ do i = 1, n
+ obj%my_data%data(i) = 1
+ end do
+
+ ! Compute the sum on the host for verification
+ sum_host = sum(obj%my_data%data)
+
+ ! Offload computation to the device using the named mapper `my_mapper2`
+ sum_device = 0
+ !$omp target map(tofrom: sum_device) map(mapper(my_mapper2) : obj)
+ do i = 1, n
+ sum_device = sum_device + obj%my_data%data(i)
+ end do
+ !$omp end target
+
+ ! Check results
+ print *, "Sum on host: ", sum_host
+ print *, "Sum on device: ", sum_device
+
+ if (sum_device == sum_host) then
+ print *, "Test passed!"
+ else
+ print *, "Test failed!"
+ end if
+ end program test_openmp_mapper
! CHECK: Test passed!
More information about the llvm-branch-commits
mailing list