[Openmp-commits] [openmp] 894531f - [Libomptarget] Add utility functions for loading an ELF symbol by name

Joseph Huber via Openmp-commits openmp-commits at lists.llvm.org
Wed Sep 7 10:39:02 PDT 2022


Author: Joseph Huber
Date: 2022-09-07T12:38:50-05:00
New Revision: 894531f59beb03c17a6e11e5b9c9995182b8d727

URL: https://github.com/llvm/llvm-project/commit/894531f59beb03c17a6e11e5b9c9995182b8d727
DIFF: https://github.com/llvm/llvm-project/commit/894531f59beb03c17a6e11e5b9c9995182b8d727.diff

LOG: [Libomptarget] Add utility functions for loading an ELF symbol by name

The `SHT_HASH` sections in an ELF are used to look up a symbol in the
symbol table using a symbol's name. This is done by obtaining the
`SHT_HASH` section and using its `sh_link` attribute to access the
associated symbol table, from which we can access the string table
containing the associated name. We can then search for the symbol using
the hash of the name and the buckets and chains in the hash table
itself.

This patch adds utility functions that allow us to look up a symbol in
an ELF file by name. It will first attempt to look through the hash
tables, and fall back to searching the symbol table manually if no hash
table is present. This allows us to pull out constants necessary for
setting up offloading without first loading the object.

Reviewed By: JonChesterfield

Differential Revision: https://reviews.llvm.org/D131309

Added: 
    openmp/libomptarget/plugins/common/elf_common/ELFSymbols.cpp
    openmp/libomptarget/plugins/common/elf_common/ELFSymbols.h

Modified: 
    openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
    openmp/libomptarget/plugins/common/elf_common/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
index 4f4881f04673c..b0e48d2351955 100644
--- a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp
@@ -10,6 +10,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPGridValues.h"
+#include "llvm/Object/ELF.h"
+#include "llvm/Object/ELFObjectFile.h"
+
 #include <algorithm>
 #include <assert.h>
 #include <cstdio>
@@ -24,6 +31,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include "ELFSymbols.h"
 #include "impl_runtime.h"
 #include "interop_hsa.h"
 
@@ -35,12 +43,8 @@
 #include "omptargetplugin.h"
 #include "print_tracing.h"
 
-#include "llvm/ADT/StringMap.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
-#include "llvm/Frontend/OpenMP/OMPGridValues.h"
-
 using namespace llvm;
+using namespace llvm::object;
 
 // hostrpc interface, FIXME: consider moving to its own include these are
 // statically linked into amdgpu/plugin if present from hostrpc_services.a,
@@ -1600,128 +1604,53 @@ template <typename T> bool enforceUpperBound(T *Value, T Upper) {
   return Changed;
 }
 
-Elf64_Shdr *findOnlyShtHash(Elf *Elf) {
-  size_t N;
-  int Rc = elf_getshdrnum(Elf, &N);
-  if (Rc != 0) {
-    return nullptr;
-  }
-
-  Elf64_Shdr *Result = nullptr;
-  for (size_t I = 0; I < N; I++) {
-    Elf_Scn *Scn = elf_getscn(Elf, I);
-    if (Scn) {
-      Elf64_Shdr *Shdr = elf64_getshdr(Scn);
-      if (Shdr) {
-        if (Shdr->sh_type == SHT_HASH) {
-          if (Result == nullptr) {
-            Result = Shdr;
-          } else {
-            // multiple SHT_HASH sections not handled
-            return nullptr;
-          }
-        }
-      }
-    }
-  }
-  return Result;
-}
-
-const Elf64_Sym *elfLookup(Elf *Elf, char *Base, Elf64_Shdr *SectionHash,
-                           const char *Symname) {
-
-  assert(SectionHash);
-  size_t SectionSymtabIndex = SectionHash->sh_link;
-  Elf64_Shdr *SectionSymtab =
-      elf64_getshdr(elf_getscn(Elf, SectionSymtabIndex));
-  size_t SectionStrtabIndex = SectionSymtab->sh_link;
-
-  const Elf64_Sym *Symtab =
-      reinterpret_cast<const Elf64_Sym *>(Base + SectionSymtab->sh_offset);
-
-  const uint32_t *Hashtab =
-      reinterpret_cast<const uint32_t *>(Base + SectionHash->sh_offset);
-
-  // Layout:
-  // nbucket
-  // nchain
-  // bucket[nbucket]
-  // chain[nchain]
-  uint32_t Nbucket = Hashtab[0];
-  const uint32_t *Bucket = &Hashtab[2];
-  const uint32_t *Chain = &Hashtab[Nbucket + 2];
-
-  const size_t Max = strlen(Symname) + 1;
-  const uint32_t Hash = elf_hash(Symname);
-  for (uint32_t I = Bucket[Hash % Nbucket]; I != 0; I = Chain[I]) {
-    char *N = elf_strptr(Elf, SectionStrtabIndex, Symtab[I].st_name);
-    if (strncmp(Symname, N, Max) == 0) {
-      return &Symtab[I];
-    }
-  }
-
-  return nullptr;
-}
-
 struct SymbolInfo {
-  void *Addr = nullptr;
+  const void *Addr = nullptr;
   uint32_t Size = UINT32_MAX;
   uint32_t ShType = SHT_NULL;
 };
 
-int getSymbolInfoWithoutLoading(Elf *Elf, char *Base, const char *Symname,
-                                SymbolInfo *Res) {
-  if (elf_kind(Elf) != ELF_K_ELF) {
-    return 1;
-  }
-
-  Elf64_Shdr *SectionHash = findOnlyShtHash(Elf);
-  if (!SectionHash) {
-    return 1;
-  }
-
-  const Elf64_Sym *Sym = elfLookup(Elf, Base, SectionHash, Symname);
-  if (!Sym) {
+int getSymbolInfoWithoutLoading(const ELFObjectFile<ELF64LE> &ELFObj,
+                                StringRef SymName, SymbolInfo *Res) {
+  auto SymOrErr = getELFSymbol(ELFObj, SymName);
+  if (!SymOrErr) {
+    std::string ErrorString = toString(SymOrErr.takeError());
+    DP("Failed ELF lookup: %s\n", ErrorString.c_str());
     return 1;
   }
-
-  if (Sym->st_size > UINT32_MAX) {
-    return 1;
-  }
-
-  if (Sym->st_shndx == SHN_UNDEF) {
-    return 1;
-  }
-
-  Elf_Scn *Section = elf_getscn(Elf, Sym->st_shndx);
-  if (!Section) {
+  if (!*SymOrErr) // The lookup itself succeeded, but no such symbol exists.
     return 1;
-  }
 
-  Elf64_Shdr *Header = elf64_getshdr(Section);
-  if (!Header) {
+  auto SymSecOrErr = ELFObj.getELFFile().getSection((*SymOrErr)->st_shndx);
+  if (!SymSecOrErr) { // Report this error; SymOrErr holds a value here.
+    std::string ErrorString = toString(SymSecOrErr.takeError());
+    DP("Failed ELF lookup: %s\n", ErrorString.c_str());
     return 1;
   }
 
-  Res->Addr = Sym->st_value + Base;
-  Res->Size = static_cast<uint32_t>(Sym->st_size);
-  Res->ShType = Header->sh_type;
+  Res->Addr = (*SymOrErr)->st_value + ELFObj.getELFFile().base();
+  Res->Size = static_cast<uint32_t>((*SymOrErr)->st_size);
+  Res->ShType = static_cast<uint32_t>((*SymSecOrErr)->sh_type);
   return 0;
 }
 
-int getSymbolInfoWithoutLoading(char *Base, size_t ImgSize, const char *Symname,
+int getSymbolInfoWithoutLoading(char *Base, size_t ImgSize, const char *SymName,
                                 SymbolInfo *Res) {
-  Elf *Elf = elf_memory(Base, ImgSize);
-  if (Elf) {
-    int Rc = getSymbolInfoWithoutLoading(Elf, Base, Symname, Res);
-    elf_end(Elf);
-    return Rc;
+  StringRef Buffer = StringRef(Base, ImgSize);
+  auto ElfOrErr = ObjectFile::createELFObjectFile(MemoryBufferRef(Buffer, ""),
+                                                  /*InitContent=*/false);
+  if (!ElfOrErr) {
+    REPORT("Failed to load ELF: %s\n", toString(ElfOrErr.takeError()).c_str());
+    return 1;
   }
+
+  if (const auto *ELFObj = dyn_cast<ELF64LEObjectFile>(ElfOrErr->get()))
+    return getSymbolInfoWithoutLoading(*ELFObj, SymName, Res);
   return 1;
 }
 
 hsa_status_t interopGetSymbolInfo(char *Base, size_t ImgSize,
-                                  const char *SymName, void **VarAddr,
+                                  const char *SymName, const void **VarAddr,
                                   uint32_t *VarSize) {
   SymbolInfo SI;
   int Rc = getSymbolInfoWithoutLoading(Base, ImgSize, SymName, &SI);
@@ -2492,7 +2421,7 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t DeviceId,
     KernDescNameStr += "_kern_desc";
     const char *KernDescName = KernDescNameStr.c_str();
 
-    void *KernDescPtr;
+    const void *KernDescPtr;
     uint32_t KernDescSize;
     void *CallStackAddr = nullptr;
     Err = interopGetSymbolInfo((char *)Image->ImageStart, ImgSize, KernDescName,
@@ -2531,7 +2460,7 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t DeviceId,
       WGSizeNameStr += "_wg_size";
       const char *WGSizeName = WGSizeNameStr.c_str();
 
-      void *WGSizePtr;
+      const void *WGSizePtr;
       uint32_t WGSize;
       Err = interopGetSymbolInfo((char *)Image->ImageStart, ImgSize, WGSizeName,
                                  &WGSizePtr, &WGSize);
@@ -2570,7 +2499,7 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t DeviceId,
     ExecModeNameStr += "_exec_mode";
     const char *ExecModeName = ExecModeNameStr.c_str();
 
-    void *ExecModePtr;
+    const void *ExecModePtr;
     uint32_t VarSize;
     Err = interopGetSymbolInfo((char *)Image->ImageStart, ImgSize, ExecModeName,
                                &ExecModePtr, &VarSize);

diff  --git a/openmp/libomptarget/plugins/common/elf_common/CMakeLists.txt b/openmp/libomptarget/plugins/common/elf_common/CMakeLists.txt
index 9ea292620f426..f6aa809b46169 100644
--- a/openmp/libomptarget/plugins/common/elf_common/CMakeLists.txt
+++ b/openmp/libomptarget/plugins/common/elf_common/CMakeLists.txt
@@ -10,7 +10,7 @@
 #
 ##===----------------------------------------------------------------------===##
 
-add_library(elf_common OBJECT elf_common.cpp)
+add_library(elf_common OBJECT elf_common.cpp ELFSymbols.cpp)
 
 # Build elf_common with PIC to be able to link it with plugin shared libraries.
 set_property(TARGET elf_common PROPERTY POSITION_INDEPENDENT_CODE ON)

diff  --git a/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.cpp b/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.cpp
new file mode 100644
index 0000000000000..5c31c769f9149
--- /dev/null
+++ b/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.cpp
@@ -0,0 +1,193 @@
+#include "ELFSymbols.h"
+
+using namespace llvm;
+using namespace llvm::object;
+using namespace llvm::ELF;
+
+/// Looks up \p Name via the GNU hash table \p HashTab; null if it is absent.
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getSymbolFromGnuHashTable(StringRef Name, const typename ELFT::GnuHash &HashTab,
+                          ArrayRef<typename ELFT::Sym> SymTab,
+                          StringRef StrTab) {
+  const uint32_t NameHash = hashGnu(Name);
+  const typename ELFT::Word NBucket = HashTab.nbuckets;
+  const typename ELFT::Word SymOffset = HashTab.symndx;
+  ArrayRef<typename ELFT::Off> Filter = HashTab.filter();
+  ArrayRef<typename ELFT::Word> Bucket = HashTab.buckets();
+  ArrayRef<typename ELFT::Word> Chain = HashTab.values(SymTab.size());
+  // Bail out on a degenerate table before the modulos below divide by zero.
+  if (NBucket == 0 || HashTab.maskwords == 0)
+    return nullptr;
+
+  // Check the bloom filter and exit early if the symbol is not present.
+  uint64_t ElfClassBits = ELFT::Is64Bits ? 64 : 32;
+  typename ELFT::Off Word =
+      Filter[(NameHash / ElfClassBits) % HashTab.maskwords];
+  uint64_t Mask = (0x1ull << (NameHash % ElfClassBits)) |
+                  (0x1ull << ((NameHash >> HashTab.shift2) % ElfClassBits));
+  if ((Word & Mask) != Mask)
+    return nullptr;
+
+  // The symbol may or may not be present, check the hash values.
+  for (typename ELFT::Word I = Bucket[NameHash % NBucket];
+       I >= SymOffset && I < SymTab.size(); I = I + 1) {
+    const uint32_t ChainHash = Chain[I - SymOffset];
+    const bool HashMatches = (NameHash | 0x1) == (ChainHash | 0x1);
+    if (HashMatches && SymTab[I].st_name >= StrTab.size())
+      return createError("symbol [index " + Twine(I) +
+                         "] has invalid st_name: " + Twine(SymTab[I].st_name));
+    if (HashMatches && StrTab.drop_front(SymTab[I].st_name).data() == Name)
+      return &SymTab[I];
+    if (ChainHash & 0x1) // Low bit set: last entry of this bucket's chain.
+      return nullptr;
+  }
+  return nullptr;
+}
+
+/// Looks up \p Name via the Sys-V hash table; null if absent or table empty.
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getSymbolFromSysVHashTable(StringRef Name, const typename ELFT::Hash &HashTab,
+                           ArrayRef<typename ELFT::Sym> SymTab,
+                           StringRef StrTab) {
+  const uint32_t Hash = hashSysV(Name);
+  const typename ELFT::Word NBucket = HashTab.nbucket;
+  ArrayRef<typename ELFT::Word> Bucket = HashTab.buckets();
+  ArrayRef<typename ELFT::Word> Chain = HashTab.chains();
+  for (uint32_t I = NBucket != 0 ? uint32_t(Bucket[Hash % NBucket]) : 0u;
+       I != ELF::STN_UNDEF; I = Chain[I]) {
+    if (I >= SymTab.size())
+      return createError(
+          "symbol [index " + Twine(I) +
+          "] is greater than the number of symbols: " + Twine(SymTab.size()));
+    if (SymTab[I].st_name >= StrTab.size())
+      return createError("symbol [index " + Twine(I) +
+                         "] has invalid st_name: " + Twine(SymTab[I].st_name));
+    if (StrTab.drop_front(SymTab[I].st_name).data() == Name)
+      return &SymTab[I];
+  }
+  return nullptr;
+}
+
+/// Looks up \p Name through the SHT_HASH or SHT_GNU_HASH section \p Sec.
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getHashTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
+                   StringRef Name) {
+  if (Sec.sh_type != ELF::SHT_HASH && Sec.sh_type != ELF::SHT_GNU_HASH)
+    return createError(
+        "invalid sh_type for hash table, expected SHT_HASH or SHT_GNU_HASH");
+  Expected<typename ELFT::ShdrRange> SectionsOrError = Elf.sections();
+  if (!SectionsOrError)
+    return SectionsOrError.takeError();
+
+  auto SymTabOrErr = getSection<ELFT>(*SectionsOrError, Sec.sh_link);
+  if (!SymTabOrErr)
+    return SymTabOrErr.takeError();
+
+  auto StrTabOrErr =
+      Elf.getStringTableForSymtab(**SymTabOrErr, *SectionsOrError);
+  if (!StrTabOrErr)
+    return StrTabOrErr.takeError();
+  StringRef StrTab = *StrTabOrErr;
+
+  auto SymsOrErr = Elf.symbols(*SymTabOrErr);
+  if (!SymsOrErr)
+    return SymsOrErr.takeError();
+  ArrayRef<typename ELFT::Sym> SymTab = *SymsOrErr;
+
+  // For a GNU hash table verify its size (bloom filter words are Off-sized)
+  // and search the symbol table using the GNU hash table format.
+  if (Sec.sh_type == ELF::SHT_GNU_HASH) {
+    const typename ELFT::GnuHash *HashTab =
+        reinterpret_cast<const typename ELFT::GnuHash *>(Elf.base() +
+                                                         Sec.sh_offset);
+    if (Sec.sh_offset + Sec.sh_size > Elf.getBufSize())
+      return createError("section has invalid sh_offset: " +
+                         Twine(Sec.sh_offset));
+    if (Sec.sh_size < sizeof(typename ELFT::GnuHash) ||
+        Sec.sh_size <
+            sizeof(typename ELFT::GnuHash) +
+                sizeof(typename ELFT::Off) * HashTab->maskwords +
+                sizeof(typename ELFT::Word) * HashTab->nbuckets +
+                sizeof(typename ELFT::Word) * (SymTab.size() - HashTab->symndx))
+      return createError("section has invalid sh_size: " + Twine(Sec.sh_size));
+    return getSymbolFromGnuHashTable<ELFT>(Name, *HashTab, SymTab, StrTab);
+  }
+
+  // If this is a Sys-V hash table we verify its size and search the symbol
+  // table using the Sys-V hash table format.
+  if (Sec.sh_type == ELF::SHT_HASH) {
+    const typename ELFT::Hash *HashTab =
+        reinterpret_cast<const typename ELFT::Hash *>(Elf.base() +
+                                                      Sec.sh_offset);
+    if (Sec.sh_offset + Sec.sh_size > Elf.getBufSize())
+      return createError("section has invalid sh_offset: " +
+                         Twine(Sec.sh_offset));
+    if (Sec.sh_size < sizeof(typename ELFT::Hash) ||
+        Sec.sh_size < sizeof(typename ELFT::Hash) +
+                          sizeof(typename ELFT::Word) * HashTab->nbucket +
+                          sizeof(typename ELFT::Word) * HashTab->nchain)
+      return createError("section has invalid sh_size: " + Twine(Sec.sh_size));
+    return getSymbolFromSysVHashTable<ELFT>(Name, *HashTab, SymTab, StrTab);
+  }
+
+  return nullptr;
+}
+
+/// Linearly searches the symbol table section \p Sec of \p Elf for \p Name.
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getSymTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
+                  StringRef Name) {
+  if (Sec.sh_type != ELF::SHT_SYMTAB && Sec.sh_type != ELF::SHT_DYNSYM)
+    return createError(
+        "invalid sh_type for symbol table, expected SHT_SYMTAB or SHT_DYNSYM");
+  Expected<typename ELFT::ShdrRange> SectionsOrError = Elf.sections();
+  if (!SectionsOrError)
+    return SectionsOrError.takeError();
+  auto StrTabOrErr = Elf.getStringTableForSymtab(Sec, *SectionsOrError);
+  if (!StrTabOrErr)
+    return StrTabOrErr.takeError();
+  StringRef StrTab = *StrTabOrErr;
+
+  auto SymsOrErr = Elf.symbols(&Sec);
+  if (!SymsOrErr)
+    return SymsOrErr.takeError();
+  ArrayRef<typename ELFT::Sym> SymTab = *SymsOrErr;
+
+  for (const typename ELFT::Sym &Sym : SymTab)
+    if (Sym.st_name < StrTab.size() && // Skip names outside the table.
+        StrTab.drop_front(Sym.st_name).data() == Name)
+      return &Sym;
+  return nullptr;
+}
+
+Expected<const typename ELF64LE::Sym *>
+getELFSymbol(const ELFObjectFile<ELF64LE> &ELFObj, StringRef Name) {
+  // First try the hash table; the first hash section's result is final.
+  for (ELFSectionRef Sec : ELFObj.sections()) {
+    if (Sec.getType() != SHT_HASH && Sec.getType() != SHT_GNU_HASH)
+      continue;
+
+    auto HashTabOrErr = ELFObj.getELFFile().getSection(Sec.getIndex());
+    if (!HashTabOrErr)
+      return HashTabOrErr.takeError();
+    return getHashTableSymbol<ELF64LE>(ELFObj.getELFFile(), **HashTabOrErr,
+                                       Name);
+  }
+
+  // Without a hash section, linearly scan the first SHT_SYMTAB section.
+  for (ELFSectionRef Sec : ELFObj.sections()) {
+    if (Sec.getType() != SHT_SYMTAB)
+      continue;
+
+    auto SymTabOrErr = ELFObj.getELFFile().getSection(Sec.getIndex());
+    if (!SymTabOrErr)
+      return SymTabOrErr.takeError();
+    return getSymTableSymbol<ELF64LE>(ELFObj.getELFFile(), **SymTabOrErr, Name);
+  }
+
+  return nullptr;
+}

diff  --git a/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.h b/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.h
new file mode 100644
index 0000000000000..589cd3436c326
--- /dev/null
+++ b/openmp/libomptarget/plugins/common/elf_common/ELFSymbols.h
@@ -0,0 +1,27 @@
+//===-- ELFSymbols.h - ELF Symbol look-up functionality ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ELF routines for obtaining a symbol from an Elf file without loading it.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_ELF_COMMON_ELF_SYMBOLS_H
+#define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_ELF_COMMON_ELF_SYMBOLS_H
+
+#include "llvm/Object/ELF.h"
+#include "llvm/Object/ELFObjectFile.h"
+
+/// Returns the symbol associated with the \p Name in the \p ELFObj. It will
+/// first search for the hash sections to identify symbols from the hash table.
+/// If that fails it will fall back to a linear search in the case of an
+/// executable file without a hash table.
+llvm::Expected<const typename llvm::object::ELF64LE::Sym *>
+getELFSymbol(const llvm::object::ELFObjectFile<llvm::object::ELF64LE> &ELFObj,
+             llvm::StringRef Name);
+
+#endif


        


More information about the Openmp-commits mailing list