[Openmp-commits] [openmp] 2210c85 - Reapply [libomptarget] Support BE ELF files in plugins-nextgen (#85246)

Ulrich Weigand via Openmp-commits openmp-commits at lists.llvm.org
Fri Mar 15 10:29:37 PDT 2024


Author: Ulrich Weigand
Date: 2024-03-15T18:28:28+01:00
New Revision: 2210c85a664463cdc11e3b7990c9663c739e6060

URL: https://github.com/llvm/llvm-project/commit/2210c85a664463cdc11e3b7990c9663c739e6060
DIFF: https://github.com/llvm/llvm-project/commit/2210c85a664463cdc11e3b7990c9663c739e6060.diff

LOG: Reapply [libomptarget] Support BE ELF files in plugins-nextgen (#85246)

Code in plugins-nextgen reading ELF files is currently hard-coded to
assume a 64-bit little-endian ELF format. Unfortunately, this assumption
is even embedded in the interface between GlobalHandler and Utils/ELF
routines, which use ELF64LE types.

To fix this, I've refactored the interface to use generic types, in
particular by using (a unique_ptr to) ObjectFile instead of
ELF64LEObjectFile, and ELFSymbolRef instead of ELF64LE::Sym.

This allows properly templating over multiple ELF format variants inside
Utils/ELF; specifically, this patch adds support for 64-bit big-endian
ELF files in addition to 64-bit little-endian files.

Added: 
    

Modified: 
    openmp/libomptarget/plugins-nextgen/common/include/GlobalHandler.h
    openmp/libomptarget/plugins-nextgen/common/include/Utils/ELF.h
    openmp/libomptarget/plugins-nextgen/common/src/GlobalHandler.cpp
    openmp/libomptarget/plugins-nextgen/common/src/Utils/ELF.cpp
    openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp

Removed: 
    


################################################################################
diff  --git a/openmp/libomptarget/plugins-nextgen/common/include/GlobalHandler.h b/openmp/libomptarget/plugins-nextgen/common/include/GlobalHandler.h
index 5c767995126b77..829b4b72911935 100644
--- a/openmp/libomptarget/plugins-nextgen/common/include/GlobalHandler.h
+++ b/openmp/libomptarget/plugins-nextgen/common/include/GlobalHandler.h
@@ -104,7 +104,7 @@ class GenericGlobalHandlerTy {
   virtual ~GenericGlobalHandlerTy() {}
 
   /// Helper function for getting an ELF from a device image.
-  Expected<ELF64LEObjectFile> getELFObjectFile(DeviceImageTy &Image);
+  Expected<std::unique_ptr<ObjectFile>> getELFObjectFile(DeviceImageTy &Image);
 
   /// Returns whether the symbol named \p SymName is present in the given \p
   /// Image.

diff  --git a/openmp/libomptarget/plugins-nextgen/common/include/Utils/ELF.h b/openmp/libomptarget/plugins-nextgen/common/include/Utils/ELF.h
index 140a6b6b84aa12..88c83d39b68ceb 100644
--- a/openmp/libomptarget/plugins-nextgen/common/include/Utils/ELF.h
+++ b/openmp/libomptarget/plugins-nextgen/common/include/Utils/ELF.h
@@ -28,17 +28,15 @@ bool isELF(llvm::StringRef Buffer);
 llvm::Expected<bool> checkMachine(llvm::StringRef Object, uint16_t EMachine);
 
 /// Returns a pointer to the given \p Symbol inside of an ELF object.
-llvm::Expected<const void *> getSymbolAddress(
-    const llvm::object::ELFObjectFile<llvm::object::ELF64LE> &ELFObj,
-    const llvm::object::ELF64LE::Sym &Symbol);
+llvm::Expected<const void *>
+getSymbolAddress(const llvm::object::ELFSymbolRef &Symbol);
 
 /// Returns the symbol associated with the \p Name in the \p ELFObj. It will
 /// first search for the hash sections to identify symbols from the hash table.
 /// If that fails it will fall back to a linear search in the case of an
 /// executable file without a hash table.
-llvm::Expected<const typename llvm::object::ELF64LE::Sym *>
-getSymbol(const llvm::object::ELFObjectFile<llvm::object::ELF64LE> &ELFObj,
-          llvm::StringRef Name);
+llvm::Expected<std::optional<llvm::object::ELFSymbolRef>>
+getSymbol(const llvm::object::ObjectFile &ELFObj, llvm::StringRef Name);
 
 } // namespace elf
 } // namespace utils

diff  --git a/openmp/libomptarget/plugins-nextgen/common/src/GlobalHandler.cpp b/openmp/libomptarget/plugins-nextgen/common/src/GlobalHandler.cpp
index d398f60c55bd13..ba0aa47f8e51c3 100644
--- a/openmp/libomptarget/plugins-nextgen/common/src/GlobalHandler.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/src/GlobalHandler.cpp
@@ -25,14 +25,12 @@ using namespace omp;
 using namespace target;
 using namespace plugin;
 
-Expected<ELF64LEObjectFile>
+Expected<std::unique_ptr<ObjectFile>>
 GenericGlobalHandlerTy::getELFObjectFile(DeviceImageTy &Image) {
   assert(utils::elf::isELF(Image.getMemoryBuffer().getBuffer()) &&
          "Input is not an ELF file");
 
-  Expected<ELF64LEObjectFile> ElfOrErr =
-      ELF64LEObjectFile::create(Image.getMemoryBuffer());
-  return ElfOrErr;
+  return ELFObjectFileBase::createELFObjectFile(Image.getMemoryBuffer());
 }
 
 Error GenericGlobalHandlerTy::moveGlobalBetweenDeviceAndHost(
@@ -91,13 +89,13 @@ bool GenericGlobalHandlerTy::isSymbolInImage(GenericDeviceTy &Device,
   }
 
   // Search the ELF symbol using the symbol name.
-  auto SymOrErr = utils::elf::getSymbol(*ELFObjOrErr, SymName);
+  auto SymOrErr = utils::elf::getSymbol(**ELFObjOrErr, SymName);
   if (!SymOrErr) {
     consumeError(SymOrErr.takeError());
     return false;
   }
 
-  return *SymOrErr;
+  return SymOrErr->has_value();
 }
 
 Error GenericGlobalHandlerTy::getGlobalMetadataFromImage(
@@ -110,17 +108,17 @@ Error GenericGlobalHandlerTy::getGlobalMetadataFromImage(
     return ELFObj.takeError();
 
   // Search the ELF symbol using the symbol name.
-  auto SymOrErr = utils::elf::getSymbol(*ELFObj, ImageGlobal.getName());
+  auto SymOrErr = utils::elf::getSymbol(**ELFObj, ImageGlobal.getName());
   if (!SymOrErr)
     return Plugin::error("Failed ELF lookup of global '%s': %s",
                          ImageGlobal.getName().data(),
                          toString(SymOrErr.takeError()).data());
 
-  if (!*SymOrErr)
+  if (!SymOrErr->has_value())
     return Plugin::error("Failed to find global symbol '%s' in the ELF image",
                          ImageGlobal.getName().data());
 
-  auto AddrOrErr = utils::elf::getSymbolAddress(*ELFObj, **SymOrErr);
+  auto AddrOrErr = utils::elf::getSymbolAddress(**SymOrErr);
   // Get the section to which the symbol belongs.
   if (!AddrOrErr)
     return Plugin::error("Failed to get ELF symbol from global '%s': %s",
@@ -129,7 +127,7 @@ Error GenericGlobalHandlerTy::getGlobalMetadataFromImage(
 
   // Setup the global symbol's address and size.
   ImageGlobal.setPtr(const_cast<void *>(*AddrOrErr));
-  ImageGlobal.setSize((*SymOrErr)->st_size);
+  ImageGlobal.setSize((*SymOrErr)->getSize());
 
   return Plugin::success();
 }

diff  --git a/openmp/libomptarget/plugins-nextgen/common/src/Utils/ELF.cpp b/openmp/libomptarget/plugins-nextgen/common/src/Utils/ELF.cpp
index c84c3bad5def0a..2ae97f0f258925 100644
--- a/openmp/libomptarget/plugins-nextgen/common/src/Utils/ELF.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/src/Utils/ELF.cpp
@@ -36,18 +36,10 @@ bool utils::elf::isELF(StringRef Buffer) {
   }
 }
 
-Expected<bool> utils::elf::checkMachine(StringRef Object, uint16_t EMachine) {
-  assert(isELF(Object) && "Input is not an ELF!");
-
-  Expected<ELF64LEObjectFile> ElfOrErr =
-      ELF64LEObjectFile::create(MemoryBufferRef(Object, /*Identifier=*/""),
-                                /*InitContent=*/false);
-  if (!ElfOrErr)
-    return ElfOrErr.takeError();
-
-  const auto Header = ElfOrErr->getELFFile().getHeader();
-  if (Header.e_ident[EI_CLASS] != ELFCLASS64)
-    return createError("Only 64-bit ELF files are supported");
+template <class ELFT>
+static Expected<bool>
+checkMachineImpl(const object::ELFObjectFile<ELFT> &ELFObj, uint16_t EMachine) {
+  const auto Header = ELFObj.getELFFile().getHeader();
   if (Header.e_type != ET_EXEC && Header.e_type != ET_DYN)
     return createError("Only executable ELF files are supported");
 
@@ -71,6 +63,25 @@ Expected<bool> utils::elf::checkMachine(StringRef Object, uint16_t EMachine) {
   return Header.e_machine == EMachine;
 }
 
+Expected<bool> utils::elf::checkMachine(StringRef Object, uint16_t EMachine) {
+  assert(isELF(Object) && "Input is not an ELF!");
+
+  Expected<std::unique_ptr<ObjectFile>> ElfOrErr =
+      ObjectFile::createELFObjectFile(
+          MemoryBufferRef(Object, /*Identifier=*/""),
+          /*InitContent=*/false);
+  if (!ElfOrErr)
+    return ElfOrErr.takeError();
+
+  if (const ELF64LEObjectFile *ELFObj =
+          dyn_cast<ELF64LEObjectFile>(&**ElfOrErr))
+    return checkMachineImpl(*ELFObj, EMachine);
+  if (const ELF64BEObjectFile *ELFObj =
+          dyn_cast<ELF64BEObjectFile>(&**ElfOrErr))
+    return checkMachineImpl(*ELFObj, EMachine);
+  return createError("Only 64-bit ELF files are supported");
+}
+
 template <class ELFT>
 static Expected<const typename ELFT::Sym *>
 getSymbolFromGnuHashTable(StringRef Name, const typename ELFT::GnuHash &HashTab,
@@ -138,9 +149,10 @@ getSymbolFromSysVHashTable(StringRef Name, const typename ELFT::Hash &HashTab,
 }
 
 template <class ELFT>
-static Expected<const typename ELFT::Sym *>
-getHashTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
-                   StringRef Name) {
+static Expected<std::optional<ELFSymbolRef>>
+getHashTableSymbol(const ELFObjectFile<ELFT> &ELFObj,
+                   const typename ELFT::Shdr &Sec, StringRef Name) {
+  const ELFFile<ELFT> &Elf = ELFObj.getELFFile();
   if (Sec.sh_type != ELF::SHT_HASH && Sec.sh_type != ELF::SHT_GNU_HASH)
     return createError(
         "invalid sh_type for hash table, expected SHT_HASH or SHT_GNU_HASH");
@@ -179,7 +191,12 @@ getHashTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
                 sizeof(typename ELFT::Word) * HashTab->nbuckets +
                 sizeof(typename ELFT::Word) * (SymTab.size() - HashTab->symndx))
       return createError("section has invalid sh_size: " + Twine(Sec.sh_size));
-    return getSymbolFromGnuHashTable<ELFT>(Name, *HashTab, SymTab, StrTab);
+    auto Sym = getSymbolFromGnuHashTable<ELFT>(Name, *HashTab, SymTab, StrTab);
+    if (!Sym)
+      return Sym.takeError();
+    if (!*Sym)
+      return std::nullopt;
+    return ELFObj.toSymbolRef(*SymTabOrErr, *Sym - &SymTab[0]);
   }
 
   // If this is a Sys-V hash table we verify its size and search the symbol
@@ -197,16 +214,22 @@ getHashTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
                           sizeof(typename ELFT::Word) * HashTab->nchain)
       return createError("section has invalid sh_size: " + Twine(Sec.sh_size));
 
-    return getSymbolFromSysVHashTable<ELFT>(Name, *HashTab, SymTab, StrTab);
+    auto Sym = getSymbolFromSysVHashTable<ELFT>(Name, *HashTab, SymTab, StrTab);
+    if (!Sym)
+      return Sym.takeError();
+    if (!*Sym)
+      return std::nullopt;
+    return ELFObj.toSymbolRef(*SymTabOrErr, *Sym - &SymTab[0]);
   }
 
-  return nullptr;
+  return std::nullopt;
 }
 
 template <class ELFT>
-static Expected<const typename ELFT::Sym *>
-getSymTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
-                  StringRef Name) {
+static Expected<std::optional<ELFSymbolRef>>
+getSymTableSymbol(const ELFObjectFile<ELFT> &ELFObj,
+                  const typename ELFT::Shdr &Sec, StringRef Name) {
+  const ELFFile<ELFT> &Elf = ELFObj.getELFFile();
   if (Sec.sh_type != ELF::SHT_SYMTAB && Sec.sh_type != ELF::SHT_DYNSYM)
     return createError(
         "invalid sh_type for hash table, expected SHT_SYMTAB or SHT_DYNSYM");
@@ -226,13 +249,14 @@ getSymTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
 
   for (const typename ELFT::Sym &Sym : SymTab)
     if (StrTab.drop_front(Sym.st_name).data() == Name)
-      return &Sym;
+      return ELFObj.toSymbolRef(&Sec, &Sym - &SymTab[0]);
 
-  return nullptr;
+  return std::nullopt;
 }
 
-Expected<const typename ELF64LE::Sym *>
-utils::elf::getSymbol(const ELFObjectFile<ELF64LE> &ELFObj, StringRef Name) {
+template <class ELFT>
+static Expected<std::optional<ELFSymbolRef>>
+getSymbolImpl(const ELFObjectFile<ELFT> &ELFObj, StringRef Name) {
   // First try to look up the symbol via the hash table.
   for (ELFSectionRef Sec : ELFObj.sections()) {
     if (Sec.getType() != SHT_HASH && Sec.getType() != SHT_GNU_HASH)
@@ -241,8 +265,7 @@ utils::elf::getSymbol(const ELFObjectFile<ELF64LE> &ELFObj, StringRef Name) {
     auto HashTabOrErr = ELFObj.getELFFile().getSection(Sec.getIndex());
     if (!HashTabOrErr)
       return HashTabOrErr.takeError();
-    return getHashTableSymbol<ELF64LE>(ELFObj.getELFFile(), **HashTabOrErr,
-                                       Name);
+    return getHashTableSymbol<ELFT>(ELFObj, **HashTabOrErr, Name);
   }
 
   // If this is an executable file check the entire standard symbol table.
@@ -253,16 +276,31 @@ utils::elf::getSymbol(const ELFObjectFile<ELF64LE> &ELFObj, StringRef Name) {
     auto SymTabOrErr = ELFObj.getELFFile().getSection(Sec.getIndex());
     if (!SymTabOrErr)
       return SymTabOrErr.takeError();
-    return getSymTableSymbol<ELF64LE>(ELFObj.getELFFile(), **SymTabOrErr, Name);
+    return getSymTableSymbol<ELFT>(ELFObj, **SymTabOrErr, Name);
   }
 
-  return nullptr;
+  return std::nullopt;
 }
 
-Expected<const void *> utils::elf::getSymbolAddress(
-    const object::ELFObjectFile<object::ELF64LE> &ELFObj,
-    const object::ELF64LE::Sym &Symbol) {
-  const ELFFile<ELF64LE> &ELFFile = ELFObj.getELFFile();
+Expected<std::optional<ELFSymbolRef>>
+utils::elf::getSymbol(const ObjectFile &Obj, StringRef Name) {
+  if (const ELF64LEObjectFile *ELFObj = dyn_cast<ELF64LEObjectFile>(&Obj))
+    return getSymbolImpl(*ELFObj, Name);
+  if (const ELF64BEObjectFile *ELFObj = dyn_cast<ELF64BEObjectFile>(&Obj))
+    return getSymbolImpl(*ELFObj, Name);
+  return createError("Only 64-bit ELF files are supported");
+}
+
+template <class ELFT>
+static Expected<const void *>
+getSymbolAddressImpl(const ELFObjectFile<ELFT> &ELFObj,
+                     const ELFSymbolRef &SymRef) {
+  const ELFFile<ELFT> &ELFFile = ELFObj.getELFFile();
+
+  auto SymOrErr = ELFObj.getSymbol(SymRef.getRawDataRefImpl());
+  if (!SymOrErr)
+    return SymOrErr.takeError();
+  const auto &Symbol = **SymOrErr;
 
   auto SecOrErr = ELFFile.getSection(Symbol.st_shndx);
   if (!SecOrErr)
@@ -283,3 +321,13 @@ Expected<const void *> utils::elf::getSymbolAddress(
 
   return ELFFile.base() + Offset;
 }
+
+Expected<const void *>
+utils::elf::getSymbolAddress(const ELFSymbolRef &SymRef) {
+  const ObjectFile *Obj = SymRef.getObject();
+  if (const ELF64LEObjectFile *ELFObj = dyn_cast<ELF64LEObjectFile>(Obj))
+    return getSymbolAddressImpl(*ELFObj, SymRef);
+  if (const ELF64BEObjectFile *ELFObj = dyn_cast<ELF64BEObjectFile>(Obj))
+    return getSymbolAddressImpl(*ELFObj, SymRef);
+  return createError("Only 64-bit ELF files are supported");
+}

diff  --git a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
index f85a00cd1cd530..b862bc74909257 100644
--- a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
+++ b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
@@ -1166,7 +1166,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
 
     // Search for all symbols that contain a constructor or destructor.
     SmallVector<std::pair<StringRef, uint16_t>> Funcs;
-    for (ELFSymbolRef Sym : ELFObjOrErr->symbols()) {
+    for (ELFSymbolRef Sym : (*ELFObjOrErr)->symbols()) {
       auto NameOrErr = Sym.getName();
       if (!NameOrErr)
         return NameOrErr.takeError();


        


More information about the Openmp-commits mailing list