[Openmp-commits] [openmp] 7648b69 - [AMDGPU][Libomptarget] Move Kernel/Symbol info tables to RTLDeviceInfoTy
Pushpinder Singh via Openmp-commits
openmp-commits at lists.llvm.org
Wed May 26 03:02:42 PDT 2021
Author: Pushpinder Singh
Date: 2021-05-26T10:02:28Z
New Revision: 7648b6978e5539bcb43b3ca24a5a53e9c6a52c1e
URL: https://github.com/llvm/llvm-project/commit/7648b6978e5539bcb43b3ca24a5a53e9c6a52c1e
DIFF: https://github.com/llvm/llvm-project/commit/7648b6978e5539bcb43b3ca24a5a53e9c6a52c1e.diff
LOG: [AMDGPU][Libomptarget] Move Kernel/Symbol info tables to RTLDeviceInfoTy
The two globals KernelInfoTable and SymbolInfoTable are moved
into the RTLDeviceInfoTy class.
This builds on top of D102691.
[2/2]
Reviewed By: JonChesterfield
Differential Revision: https://reviews.llvm.org/D102692
Added:
Modified:
openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp
openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h
openmp/libomptarget/plugins/amdgpu/impl/internal.h
openmp/libomptarget/plugins/amdgpu/impl/system.cpp
openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
Removed:
################################################################################
diff --git a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp
index eb4a46c35a9b7..dc563ee40f7bf 100644
--- a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp
+++ b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp
@@ -8,10 +8,10 @@
using core::atl_is_atmi_initialized;
-atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
- const char *symbol,
- void **var_addr,
- unsigned int *var_size) {
+atmi_status_t atmi_interop_hsa_get_symbol_info(
+ const std::map<std::string, atl_symbol_info_t> &SymbolInfoTable,
+ atmi_mem_place_t place, const char *symbol, void **var_addr,
+ unsigned int *var_size) {
/*
// Typical usage:
void *var_addr;
@@ -32,9 +32,9 @@ atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
// get the symbol info
std::string symbolStr = std::string(symbol);
- if (SymbolInfoTable[place.dev_id].find(symbolStr) !=
- SymbolInfoTable[place.dev_id].end()) {
- atl_symbol_info_t info = SymbolInfoTable[place.dev_id][symbolStr];
+ auto It = SymbolInfoTable.find(symbolStr);
+ if (It != SymbolInfoTable.end()) {
+ atl_symbol_info_t info = It->second;
*var_addr = reinterpret_cast<void *>(info.addr);
*var_size = info.size;
return ATMI_STATUS_SUCCESS;
@@ -46,6 +46,7 @@ atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
}
atmi_status_t atmi_interop_hsa_get_kernel_info(
+ const std::map<std::string, atl_kernel_info_t> &KernelInfoTable,
atmi_mem_place_t place, const char *kernel_name,
hsa_executable_symbol_info_t kernel_info, uint32_t *value) {
/*
@@ -68,9 +69,9 @@ atmi_status_t atmi_interop_hsa_get_kernel_info(
atmi_status_t status = ATMI_STATUS_SUCCESS;
// get the kernel info
std::string kernelStr = std::string(kernel_name);
- if (KernelInfoTable[place.dev_id].find(kernelStr) !=
- KernelInfoTable[place.dev_id].end()) {
- atl_kernel_info_t info = KernelInfoTable[place.dev_id][kernelStr];
+ auto It = KernelInfoTable.find(kernelStr);
+ if (It != KernelInfoTable.end()) {
+ atl_kernel_info_t info = It->second;
switch (kernel_info) {
case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE:
*value = info.group_segment_size;
diff --git a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h
index c0f588215e8a2..20da1173a8dba 100644
--- a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h
+++ b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h
@@ -9,6 +9,10 @@
#include "atmi_runtime.h"
#include "hsa.h"
#include "hsa_ext_amd.h"
+#include "internal.h"
+
+#include <map>
+#include <string>
#ifdef __cplusplus
extern "C" {
@@ -44,11 +48,10 @@ extern "C" {
*
* @retval ::ATMI_STATUS_UNKNOWN The function encountered errors.
*/
-atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
- const char *symbol,
- void **var_addr,
- unsigned int *var_size);
-
+atmi_status_t atmi_interop_hsa_get_symbol_info(
+ const std::map<std::string, atl_symbol_info_t> &SymbolInfoTable,
+ atmi_mem_place_t place, const char *symbol, void **var_addr,
+ unsigned int *var_size);
/**
* @brief Get the HSA-specific kernel info from a kernel name
*
@@ -75,8 +78,10 @@ atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
* @retval ::ATMI_STATUS_UNKNOWN The function encountered errors.
*/
atmi_status_t atmi_interop_hsa_get_kernel_info(
+ const std::map<std::string, atl_kernel_info_t> &KernelInfoTable,
atmi_mem_place_t place, const char *kernel_name,
hsa_executable_symbol_info_t info, uint32_t *value);
+
/** @} */
#ifdef __cplusplus
diff --git a/openmp/libomptarget/plugins/amdgpu/impl/internal.h b/openmp/libomptarget/plugins/amdgpu/impl/internal.h
index ef06839873498..98d9ee487fe96 100644
--- a/openmp/libomptarget/plugins/amdgpu/impl/internal.h
+++ b/openmp/libomptarget/plugins/amdgpu/impl/internal.h
@@ -106,9 +106,6 @@ typedef struct atl_symbol_info_s {
uint32_t size;
} atl_symbol_info_t;
-extern std::vector<std::map<std::string, atl_kernel_info_t>> KernelInfoTable;
-extern std::vector<std::map<std::string, atl_symbol_info_t>> SymbolInfoTable;
-
// ---------------------- Kernel End -------------
namespace core {
diff --git a/openmp/libomptarget/plugins/amdgpu/impl/system.cpp b/openmp/libomptarget/plugins/amdgpu/impl/system.cpp
index f3a7d20be0ddd..ac171022bad41 100644
--- a/openmp/libomptarget/plugins/amdgpu/impl/system.cpp
+++ b/openmp/libomptarget/plugins/amdgpu/impl/system.cpp
@@ -146,9 +146,6 @@ ATLMachine g_atl_machine;
std::vector<hsa_amd_memory_pool_t> atl_gpu_kernarg_pools;
-std::vector<std::map<std::string, atl_kernel_info_t>> KernelInfoTable;
-std::vector<std::map<std::string, atl_symbol_info_t>> SymbolInfoTable;
-
bool g_atmi_initialized = false;
/*
@@ -208,15 +205,6 @@ atmi_status_t Runtime::Initialize() {
atmi_status_t Runtime::Finalize() {
atmi_status_t rc = ATMI_STATUS_SUCCESS;
- for (uint32_t i = 0; i < SymbolInfoTable.size(); i++) {
- SymbolInfoTable[i].clear();
- }
- SymbolInfoTable.clear();
- for (uint32_t i = 0; i < KernelInfoTable.size(); i++) {
- KernelInfoTable[i].clear();
- }
- KernelInfoTable.clear();
-
atl_reset_atmi_initialized();
hsa_status_t err = hsa_shut_down();
if (err != HSA_STATUS_SUCCESS) {
@@ -556,13 +544,6 @@ hsa_status_t init_hsa() {
return err;
}
- int gpu_count = g_atl_machine.processorCount<ATLGPUProcessor>();
- KernelInfoTable.resize(gpu_count);
- SymbolInfoTable.resize(gpu_count);
- for (uint32_t i = 0; i < SymbolInfoTable.size(); i++)
- SymbolInfoTable[i].clear();
- for (uint32_t i = 0; i < KernelInfoTable.size(); i++)
- KernelInfoTable[i].clear();
atlc.g_hsa_initialized = true;
DEBUG_PRINT("done\n");
}
@@ -835,8 +816,9 @@ int populate_kernelArgMD(msgpack::byte_range args_element,
}
} // namespace
-static hsa_status_t get_code_object_custom_metadata(void *binary,
- size_t binSize, int gpu) {
+static hsa_status_t get_code_object_custom_metadata(
+ void *binary, size_t binSize, int gpu,
+ std::map<std::string, atl_kernel_info_t> &KernelInfoTable) {
// parse code object with different keys from v2
// also, the kernel name is not the same as the symbol name -- so a
// symbol->name map is needed
@@ -1003,14 +985,16 @@ static hsa_status_t get_code_object_custom_metadata(void *binary,
kernel_segment_size, info.kernel_segment_size);
// kernel received, now add it to the kernel info table
- KernelInfoTable[gpu][kernelName] = info;
+ KernelInfoTable[kernelName] = info;
}
return HSA_STATUS_SUCCESS;
}
-static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol,
- int gpu) {
+static hsa_status_t
+populate_InfoTables(hsa_executable_symbol_t symbol, int gpu,
+ std::map<std::string, atl_kernel_info_t> &KernelInfoTable,
+ std::map<std::string, atl_symbol_info_t> &SymbolInfoTable) {
hsa_symbol_kind_t type;
uint32_t name_length;
@@ -1047,11 +1031,16 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol,
// by now, the kernel info table should already have an entry
// because the non-ROCr custom code object parsing is called before
// iterating over the code object symbols using ROCr
- if (KernelInfoTable[gpu].find(kernelName) == KernelInfoTable[gpu].end()) {
- return HSA_STATUS_ERROR;
+ if (KernelInfoTable.find(kernelName) == KernelInfoTable.end()) {
+ if (HSA_STATUS_ERROR_INVALID_CODE_OBJECT != HSA_STATUS_SUCCESS) {
+ printf("[%s:%d] %s failed: %s\n", __FILE__, __LINE__,
+ "Finding the entry kernel info table",
+ get_error_string(HSA_STATUS_ERROR_INVALID_CODE_OBJECT));
+ exit(1);
+ }
}
// found, so assign and update
- info = KernelInfoTable[gpu][kernelName];
+ info = KernelInfoTable[kernelName];
/* Extract dispatch information from the symbol */
err = hsa_executable_symbol_get_info(
@@ -1089,7 +1078,7 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol,
info.private_segment_size, info.kernel_segment_size);
// assign it back to the kernel info table
- KernelInfoTable[gpu][kernelName] = info;
+ KernelInfoTable[kernelName] = info;
free(name);
} else if (type == HSA_SYMBOL_KIND_VARIABLE) {
err = hsa_executable_symbol_get_info(
@@ -1135,7 +1124,7 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol,
if (err != HSA_STATUS_SUCCESS) {
return err;
}
- SymbolInfoTable[gpu][std::string(name)] = info;
+ SymbolInfoTable[std::string(name)] = info;
free(name);
} else {
DEBUG_PRINT("Symbol is an indirect function\n");
@@ -1143,7 +1132,9 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol,
return HSA_STATUS_SUCCESS;
}
-atmi_status_t Runtime::RegisterModuleFromMemory(
+atmi_status_t RegisterModuleFromMemory(
+ std::map<std::string, atl_kernel_info_t> &KernelInfoTable,
+ std::map<std::string, atl_symbol_info_t> &SymbolInfoTable,
void *module_bytes, size_t module_size, atmi_place_t place,
atmi_status_t (*on_deserialized_data)(void *data, size_t size,
void *cb_state),
@@ -1183,7 +1174,8 @@ atmi_status_t Runtime::RegisterModuleFromMemory(
// Some metadata info is not available through ROCr API, so use custom
// code object metadata parsing to collect such metadata info
- err = get_code_object_custom_metadata(module_bytes, module_size, gpu);
+ err = get_code_object_custom_metadata(module_bytes, module_size, gpu,
+ KernelInfoTable);
if (err != HSA_STATUS_SUCCESS) {
DEBUG_PRINT("[%s:%d] %s failed: %s\n", __FILE__, __LINE__,
"Getting custom code object metadata",
@@ -1240,9 +1232,9 @@ atmi_status_t Runtime::RegisterModuleFromMemory(
err = hsa::executable_iterate_symbols(
executable,
[&](hsa_executable_t, hsa_executable_symbol_t symbol) -> hsa_status_t {
- return populate_InfoTables(symbol, gpu);
+ return populate_InfoTables(symbol, gpu, KernelInfoTable,
+ SymbolInfoTable);
});
-
if (err != HSA_STATUS_SUCCESS) {
printf("[%s:%d] %s failed: %s\n", __FILE__, __LINE__,
"Iterating over symbols for execuatable", get_error_string(err));
diff --git a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
index b67f3cf45023b..4883288e0725c 100644
--- a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
@@ -86,6 +86,16 @@ int print_kernel_trace;
#include "elf_common.h"
+namespace core {
+atmi_status_t RegisterModuleFromMemory(
+ std::map<std::string, atl_kernel_info_t> &KernelInfo,
+ std::map<std::string, atl_symbol_info_t> &SymbolInfoTable, void *, size_t,
+ atmi_place_t,
+ atmi_status_t (*on_deserialized_data)(void *data, size_t size,
+ void *cb_state),
+ void *cb_state, std::vector<hsa_executable_t> &HSAExecutables);
+}
+
/// Keep entries table per device
struct FuncOrGblEntryTy {
__tgt_target_table Table;
@@ -339,6 +349,9 @@ class RTLDeviceInfoTy {
std::vector<hsa_executable_t> HSAExecutables;
+ std::vector<std::map<std::string, atl_kernel_info_t>> KernelInfoTable;
+ std::vector<std::map<std::string, atl_symbol_info_t>> SymbolInfoTable;
+
struct atmiFreePtrDeletor {
void operator()(void *p) {
atmi_free(p); // ignore failure to free
@@ -482,6 +495,8 @@ class RTLDeviceInfoTy {
NumTeams.resize(NumberOfDevices);
NumThreads.resize(NumberOfDevices);
deviceStateStore.resize(NumberOfDevices);
+ KernelInfoTable.resize(NumberOfDevices);
+ SymbolInfoTable.resize(NumberOfDevices);
for (int i = 0; i < NumberOfDevices; i++) {
HSAQueues[i] = nullptr;
@@ -993,15 +1008,17 @@ atmi_status_t interop_get_symbol_info(char *base, size_t img_size,
template <typename C>
atmi_status_t module_register_from_memory_to_place(
+ std::map<std::string, atl_kernel_info_t> &KernelInfoTable,
+ std::map<std::string, atl_symbol_info_t> &SymbolInfoTable,
void *module_bytes, size_t module_size, atmi_place_t place, C cb,
std::vector<hsa_executable_t> &HSAExecutables) {
auto L = [](void *data, size_t size, void *cb_state) -> atmi_status_t {
C *unwrapped = static_cast<C *>(cb_state);
return (*unwrapped)(data, size);
};
- return core::Runtime::RegisterModuleFromMemory(
- module_bytes, module_size, place, L, static_cast<void *>(&cb),
- HSAExecutables);
+ return core::RegisterModuleFromMemory(
+ KernelInfoTable, SymbolInfoTable, module_bytes, module_size, place, L,
+ static_cast<void *>(&cb), HSAExecutables);
}
} // namespace
@@ -1116,11 +1133,12 @@ struct device_environment {
DP("Setting global device environment after load (%u bytes)\n",
si.size);
int device_id = host_device_env.device_num;
-
+ auto &SymbolInfo = DeviceInfo.SymbolInfoTable[device_id];
void *state_ptr;
uint32_t state_ptr_size;
atmi_status_t err = atmi_interop_hsa_get_symbol_info(
- get_gpu_mem_place(device_id), sym(), &state_ptr, &state_ptr_size);
+ SymbolInfo, get_gpu_mem_place(device_id), sym(), &state_ptr,
+ &state_ptr_size);
if (err != ATMI_STATUS_SUCCESS) {
DP("failed to find %s in loaded image\n", sym());
return err;
@@ -1205,8 +1223,11 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
auto env = device_environment(device_id, DeviceInfo.NumberOfDevices, image,
img_size);
+ auto &KernelInfo = DeviceInfo.KernelInfoTable[device_id];
+ auto &SymbolInfo = DeviceInfo.SymbolInfoTable[device_id];
atmi_status_t err = module_register_from_memory_to_place(
- (void *)image->ImageStart, img_size, get_gpu_place(device_id),
+ KernelInfo, SymbolInfo, (void *)image->ImageStart, img_size,
+ get_gpu_place(device_id),
[&](void *data, size_t size) {
if (image_contains_symbol(data, size, "needs_hostcall_buffer")) {
__atomic_store_n(&DeviceInfo.hostcall_required, true,
@@ -1241,9 +1262,10 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
void *state_ptr;
uint32_t state_ptr_size;
+ auto &SymbolInfoMap = DeviceInfo.SymbolInfoTable[device_id];
atmi_status_t err = atmi_interop_hsa_get_symbol_info(
- get_gpu_mem_place(device_id), "omptarget_nvptx_device_State",
- &state_ptr, &state_ptr_size);
+ SymbolInfoMap, get_gpu_mem_place(device_id),
+ "omptarget_nvptx_device_State", &state_ptr, &state_ptr_size);
if (err != ATMI_STATUS_SUCCESS) {
DP("No device_state symbol found, skipping initialization\n");
@@ -1325,8 +1347,10 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
void *varptr;
uint32_t varsize;
+ auto &SymbolInfoMap = DeviceInfo.SymbolInfoTable[device_id];
atmi_status_t err = atmi_interop_hsa_get_symbol_info(
- get_gpu_mem_place(device_id), e->name, &varptr, &varsize);
+ SymbolInfoMap, get_gpu_mem_place(device_id), e->name, &varptr,
+ &varsize);
if (err != ATMI_STATUS_SUCCESS) {
// Inform the user what symbol prevented offloading
@@ -1367,8 +1391,10 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
atmi_mem_place_t place = get_gpu_mem_place(device_id);
uint32_t kernarg_segment_size;
+ auto &KernelInfoMap = DeviceInfo.KernelInfoTable[device_id];
atmi_status_t err = atmi_interop_hsa_get_kernel_info(
- place, e->name, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE,
+ KernelInfoMap, place, e->name,
+ HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE,
&kernarg_segment_size);
// each arg is a void * in this openmp implementation
@@ -1794,6 +1820,7 @@ int32_t __tgt_rtl_run_target_team_region_locked(
KernelTy *KernelInfo = (KernelTy *)tgt_entry_ptr;
std::string kernel_name = std::string(KernelInfo->Name);
+ auto &KernelInfoTable = DeviceInfo.KernelInfoTable;
if (KernelInfoTable[device_id].find(kernel_name) ==
KernelInfoTable[device_id].end()) {
DP("Kernel %s not found\n", kernel_name.c_str());
More information about the Openmp-commits
mailing list