[llvm] b56df19 - [BitcodeReader] Allow reading pointer types from old IR

Sebastian Neubauer via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 17 04:19:59 PST 2023


Author: Sebastian Neubauer
Date: 2023-01-17T13:19:40+01:00
New Revision: b56df190b01335506ce30a4559d880da76d1a181

URL: https://github.com/llvm/llvm-project/commit/b56df190b01335506ce30a4559d880da76d1a181
DIFF: https://github.com/llvm/llvm-project/commit/b56df190b01335506ce30a4559d880da76d1a181.diff

LOG: [BitcodeReader] Allow reading pointer types from old IR

When opaque pointers are enabled and old IR with typed pointers is read,
the BitcodeReader automatically upgrades all typed pointers to opaque
pointers. This is a lossy conversion: for example, when a function argument
is a pointer that is never used, it is impossible to reconstruct the original
type behind the pointer.

There are cases where the type information of pointers is needed. One is
reading DXIL, which is bitcode of an old LLVM IR version and makes heavy use
of pointers in function signatures.
We’d like to keep using an up-to-date LLVM to read in and process DXIL, so
with opaque pointers we need some way to access the pointer type information
from the bitcode being read.

This patch allows extracting type information by supplying callbacks to
parseBitcodeFile that get called for each function signature or metadata
value. The callbacks can access the type information via the reader’s
type IDs and the getTypeByID and getContainedTypeID functions.
The tests show, as an example, how type info from pointers can be stored in
metadata for use after the BitcodeReader has finished.

Differential Revision: https://reviews.llvm.org/D127728

Added: 
    

Modified: 
    llvm/include/llvm/Bitcode/BitcodeReader.h
    llvm/include/llvm/IRReader/IRReader.h
    llvm/lib/Bitcode/Reader/BitcodeReader.cpp
    llvm/lib/Bitcode/Reader/MetadataLoader.cpp
    llvm/lib/Bitcode/Reader/MetadataLoader.h
    llvm/lib/IRReader/IRReader.cpp
    llvm/tools/llc/llc.cpp
    llvm/tools/opt/opt.cpp
    llvm/unittests/Bitcode/BitReaderTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Bitcode/BitcodeReader.h b/llvm/include/llvm/Bitcode/BitcodeReader.h
index a4b300d019a3a..5f87445eff1df 100644
--- a/llvm/include/llvm/Bitcode/BitcodeReader.h
+++ b/llvm/include/llvm/Bitcode/BitcodeReader.h
@@ -32,27 +32,60 @@ namespace llvm {
 class LLVMContext;
 class Module;
 class MemoryBuffer;
+class Metadata;
 class ModuleSummaryIndex;
+class Type;
+class Value;
 
 // Callback to override the data layout string of an imported bitcode module.
 // The first argument is the target triple, the second argument the data layout
 // string from the input, or a default string. It will be used if the callback
 // returns std::nullopt.
-typedef llvm::function_ref<std::optional<std::string>(StringRef, StringRef)>
-    DataLayoutCallbackTy;
-
-  // These functions are for converting Expected/Error values to
-  // ErrorOr/std::error_code for compatibility with legacy clients. FIXME:
-  // Remove these functions once no longer needed by the C and libLTO APIs.
-
-  std::error_code errorToErrorCodeAndEmitErrors(LLVMContext &Ctx, Error Err);
-
-  template <typename T>
-  ErrorOr<T> expectedToErrorOrAndEmitErrors(LLVMContext &Ctx, Expected<T> Val) {
-    if (!Val)
-      return errorToErrorCodeAndEmitErrors(Ctx, Val.takeError());
-    return std::move(*Val);
-  }
+typedef std::function<std::optional<std::string>(StringRef, StringRef)>
+    DataLayoutCallbackFuncTy;
+
+typedef std::function<Type *(unsigned)> GetTypeByIDTy;
+
+typedef std::function<unsigned(unsigned, unsigned)> GetContainedTypeIDTy;
+
+typedef std::function<void(Value *, unsigned, GetTypeByIDTy,
+                           GetContainedTypeIDTy)>
+    ValueTypeCallbackTy;
+
+typedef std::function<void(Metadata **, unsigned, GetTypeByIDTy,
+                           GetContainedTypeIDTy)>
+    MDTypeCallbackTy;
+
+// These functions are for converting Expected/Error values to
+// ErrorOr/std::error_code for compatibility with legacy clients. FIXME:
+// Remove these functions once no longer needed by the C and libLTO APIs.
+
+std::error_code errorToErrorCodeAndEmitErrors(LLVMContext &Ctx, Error Err);
+
+template <typename T>
+ErrorOr<T> expectedToErrorOrAndEmitErrors(LLVMContext &Ctx, Expected<T> Val) {
+  if (!Val)
+    return errorToErrorCodeAndEmitErrors(Ctx, Val.takeError());
+  return std::move(*Val);
+}
+
+struct ParserCallbacks {
+  std::optional<DataLayoutCallbackFuncTy> DataLayout;
+  /// The ValueType callback is called for every function definition or
+  /// declaration and allows accessing the type information, also behind
+  /// pointers. This can be useful when the opaque pointer upgrade erases all
+  /// type information behind pointers.
+  /// The second argument to ValueTypeCallback is the type ID of the
+  /// function, the two passed functions can be used to extract type
+  /// information.
+  std::optional<ValueTypeCallbackTy> ValueType;
+  /// The MDType callback is called for every value in metadata.
+  std::optional<MDTypeCallbackTy> MDType;
+
+  ParserCallbacks() = default;
+  explicit ParserCallbacks(DataLayoutCallbackFuncTy DataLayout)
+      : DataLayout(DataLayout) {}
+};
 
   struct BitcodeFileContents;
 
@@ -90,7 +123,7 @@ typedef llvm::function_ref<std::optional<std::string>(StringRef, StringRef)>
     Expected<std::unique_ptr<Module>>
     getModuleImpl(LLVMContext &Context, bool MaterializeAll,
                   bool ShouldLazyLoadMetadata, bool IsImporting,
-                  DataLayoutCallbackTy DataLayoutCallback);
+                  ParserCallbacks Callbacks = {});
 
   public:
     StringRef getBuffer() const {
@@ -105,18 +138,13 @@ typedef llvm::function_ref<std::optional<std::string>(StringRef, StringRef)>
     /// bodies. If ShouldLazyLoadMetadata is true, lazily load metadata as well.
     /// If IsImporting is true, this module is being parsed for ThinLTO
     /// importing into another module.
-    Expected<std::unique_ptr<Module>> getLazyModule(
-        LLVMContext &Context, bool ShouldLazyLoadMetadata, bool IsImporting,
-        DataLayoutCallbackTy DataLayoutCallback = [](StringRef, StringRef) {
-          return std::nullopt;
-        });
+    Expected<std::unique_ptr<Module>>
+    getLazyModule(LLVMContext &Context, bool ShouldLazyLoadMetadata,
+                  bool IsImporting, ParserCallbacks Callbacks = {});
 
     /// Read the entire bitcode module and return it.
-    Expected<std::unique_ptr<Module>> parseModule(
-        LLVMContext &Context,
-        DataLayoutCallbackTy DataLayoutCallback = [](StringRef, StringRef) {
-          return std::nullopt;
-        });
+    Expected<std::unique_ptr<Module>>
+    parseModule(LLVMContext &Context, ParserCallbacks Callbacks = {});
 
     /// Returns information about the module to be used for LTO: whether to
     /// compile with ThinLTO, and whether it has a summary.
@@ -153,12 +181,11 @@ typedef llvm::function_ref<std::optional<std::string>(StringRef, StringRef)>
   /// deserialization of function bodies. If ShouldLazyLoadMetadata is true,
   /// lazily load metadata as well. If IsImporting is true, this module is
   /// being parsed for ThinLTO importing into another module.
-  Expected<std::unique_ptr<Module>> getLazyBitcodeModule(
-      MemoryBufferRef Buffer, LLVMContext &Context,
-      bool ShouldLazyLoadMetadata = false, bool IsImporting = false,
-      DataLayoutCallbackTy DataLayoutCallback = [](StringRef, StringRef) {
-        return std::nullopt;
-      });
+  Expected<std::unique_ptr<Module>>
+  getLazyBitcodeModule(MemoryBufferRef Buffer, LLVMContext &Context,
+                       bool ShouldLazyLoadMetadata = false,
+                       bool IsImporting = false,
+                       ParserCallbacks Callbacks = {});
 
   /// Like getLazyBitcodeModule, except that the module takes ownership of
   /// the memory buffer if successful. If successful, this moves Buffer. On
@@ -166,7 +193,8 @@ typedef llvm::function_ref<std::optional<std::string>(StringRef, StringRef)>
   /// being parsed for ThinLTO importing into another module.
   Expected<std::unique_ptr<Module>> getOwningLazyBitcodeModule(
       std::unique_ptr<MemoryBuffer> &&Buffer, LLVMContext &Context,
-      bool ShouldLazyLoadMetadata = false, bool IsImporting = false);
+      bool ShouldLazyLoadMetadata = false, bool IsImporting = false,
+      ParserCallbacks Callbacks = {});
 
   /// Read the header of the specified bitcode buffer and extract just the
   /// triple information. If successful, this returns a string. On error, this
@@ -183,11 +211,9 @@ typedef llvm::function_ref<std::optional<std::string>(StringRef, StringRef)>
   Expected<std::string> getBitcodeProducerString(MemoryBufferRef Buffer);
 
   /// Read the specified bitcode file, returning the module.
-  Expected<std::unique_ptr<Module>> parseBitcodeFile(
-      MemoryBufferRef Buffer, LLVMContext &Context,
-      DataLayoutCallbackTy DataLayoutCallback = [](StringRef, StringRef) {
-        return std::nullopt;
-      });
+  Expected<std::unique_ptr<Module>>
+  parseBitcodeFile(MemoryBufferRef Buffer, LLVMContext &Context,
+                   ParserCallbacks Callbacks = {});
 
   /// Returns LTO information for the specified bitcode file.
   Expected<BitcodeLTOInfo> getBitcodeLTOInfo(MemoryBufferRef Buffer);

diff  --git a/llvm/include/llvm/IRReader/IRReader.h b/llvm/include/llvm/IRReader/IRReader.h
index 4eabf473ce53d..644fea82bfbe0 100644
--- a/llvm/include/llvm/IRReader/IRReader.h
+++ b/llvm/include/llvm/IRReader/IRReader.h
@@ -16,6 +16,7 @@
 
 #include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Bitcode/BitcodeReader.h"
 #include <memory>
 #include <optional>
 
@@ -27,9 +28,6 @@ class Module;
 class SMDiagnostic;
 class LLVMContext;
 
-typedef llvm::function_ref<std::optional<std::string>(StringRef, StringRef)>
-    DataLayoutCallbackTy;
-
 /// If the given MemoryBuffer holds a bitcode image, return a Module
 /// for it which does lazy deserialization of function bodies.  Otherwise,
 /// attempt to parse it as LLVM Assembly and return a fully populated
@@ -53,21 +51,17 @@ getLazyIRFileModule(StringRef Filename, SMDiagnostic &Err, LLVMContext &Context,
 /// for it.  Otherwise, attempt to parse it as LLVM Assembly and return
 /// a Module for it.
 /// \param DataLayoutCallback Override datalayout in the llvm assembly.
-std::unique_ptr<Module> parseIR(
-    MemoryBufferRef Buffer, SMDiagnostic &Err, LLVMContext &Context,
-    DataLayoutCallbackTy DataLayoutCallback = [](StringRef, StringRef) {
-      return std::nullopt;
-    });
+std::unique_ptr<Module> parseIR(MemoryBufferRef Buffer, SMDiagnostic &Err,
+                                LLVMContext &Context,
+                                ParserCallbacks Callbacks = {});
 
 /// If the given file holds a bitcode image, return a Module for it.
 /// Otherwise, attempt to parse it as LLVM Assembly and return a Module
 /// for it.
 /// \param DataLayoutCallback Override datalayout in the llvm assembly.
-std::unique_ptr<Module> parseIRFile(
-    StringRef Filename, SMDiagnostic &Err, LLVMContext &Context,
-    DataLayoutCallbackTy DataLayoutCallback = [](StringRef, StringRef) {
-      return std::nullopt;
-    });
+std::unique_ptr<Module> parseIRFile(StringRef Filename, SMDiagnostic &Err,
+                                    LLVMContext &Context,
+                                    ParserCallbacks Callbacks = {});
 }
 
 #endif

diff  --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index aa33006dd5963..0a346e342de22 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -674,6 +674,8 @@ class BitcodeReader : public BitcodeReaderBase, public GVMaterializer {
   std::vector<std::string> BundleTags;
   SmallVector<SyncScope::ID, 8> SSIDs;
 
+  std::optional<ValueTypeCallbackTy> ValueTypeCallback;
+
 public:
   BitcodeReader(BitstreamCursor Stream, StringRef Strtab,
                 StringRef ProducerIdentification, LLVMContext &Context);
@@ -686,9 +688,8 @@ class BitcodeReader : public BitcodeReaderBase, public GVMaterializer {
 
   /// Main interface to parsing a bitcode buffer.
   /// \returns true if an error occurred.
-  Error parseBitcodeInto(
-      Module *M, bool ShouldLazyLoadMetadata, bool IsImporting,
-      DataLayoutCallbackTy DataLayoutCallback);
+  Error parseBitcodeInto(Module *M, bool ShouldLazyLoadMetadata,
+                         bool IsImporting, ParserCallbacks Callbacks = {});
 
   static uint64_t decodeSignRotatedValue(uint64_t V);
 
@@ -709,6 +710,7 @@ class BitcodeReader : public BitcodeReaderBase, public GVMaterializer {
   unsigned getContainedTypeID(unsigned ID, unsigned Idx = 0);
   unsigned getVirtualTypeID(Type *Ty, ArrayRef<unsigned> ContainedTypeIDs = {});
 
+  void callValueTypeCallback(Value *F, unsigned TypeID);
   Expected<Value *> materializeValue(unsigned ValID, BasicBlock *InsertBB);
   Expected<Constant *> getValueForInitializer(unsigned ID);
 
@@ -819,11 +821,8 @@ class BitcodeReader : public BitcodeReaderBase, public GVMaterializer {
   /// a corresponding error code.
   Error parseAlignmentValue(uint64_t Exponent, MaybeAlign &Alignment);
   Error parseAttrKind(uint64_t Code, Attribute::AttrKind *Kind);
-  Error parseModule(
-      uint64_t ResumeBit, bool ShouldLazyLoadMetadata = false,
-      DataLayoutCallbackTy DataLayoutCallback = [](StringRef, StringRef) {
-        return std::nullopt;
-      });
+  Error parseModule(uint64_t ResumeBit, bool ShouldLazyLoadMetadata = false,
+                    ParserCallbacks Callbacks = {});
 
   Error parseComdatRecord(ArrayRef<uint64_t> Record);
   Error parseGlobalVarRecord(ArrayRef<uint64_t> Record);
@@ -3919,6 +3918,14 @@ Error BitcodeReader::parseGlobalVarRecord(ArrayRef<uint64_t> Record) {
   return Error::success();
 }
 
+void BitcodeReader::callValueTypeCallback(Value *F, unsigned TypeID) {
+  if (ValueTypeCallback) {
+    (*ValueTypeCallback)(
+        F, TypeID, [this](unsigned I) { return getTypeByID(I); },
+        [this](unsigned I, unsigned J) { return getContainedTypeID(I, J); });
+  }
+}
+
 Error BitcodeReader::parseFunctionRecord(ArrayRef<uint64_t> Record) {
   // v1: [type, callingconv, isproto, linkage, paramattr, alignment, section,
   // visibility, gc, unnamed_addr, prologuedata, dllstorageclass, comdat,
@@ -3963,6 +3970,7 @@ Error BitcodeReader::parseFunctionRecord(ArrayRef<uint64_t> Record) {
   uint64_t RawLinkage = Record[3];
   Func->setLinkage(getDecodedLinkage(RawLinkage));
   Func->setAttributes(getAttributes(Record[4]));
+  callValueTypeCallback(Func, FTyID);
 
   // Upgrade any old-style byval or sret without a type by propagating the
   // argument's pointee type. There should be no opaque pointers where the byval
@@ -4180,7 +4188,8 @@ Error BitcodeReader::parseGlobalIndirectSymbolRecord(
 
 Error BitcodeReader::parseModule(uint64_t ResumeBit,
                                  bool ShouldLazyLoadMetadata,
-                                 DataLayoutCallbackTy DataLayoutCallback) {
+                                 ParserCallbacks Callbacks) {
+  this->ValueTypeCallback = std::move(Callbacks.ValueType);
   if (ResumeBit) {
     if (Error JumpFailed = Stream.JumpToBit(ResumeBit))
       return JumpFailed;
@@ -4210,9 +4219,11 @@ Error BitcodeReader::parseModule(uint64_t ResumeBit,
         TentativeDataLayoutStr, TheModule->getTargetTriple());
 
     // Apply override
-    if (auto LayoutOverride = DataLayoutCallback(TheModule->getTargetTriple(),
-                                                 TentativeDataLayoutStr))
-      TentativeDataLayoutStr = *LayoutOverride;
+    if (Callbacks.DataLayout) {
+      if (auto LayoutOverride = (*Callbacks.DataLayout)(
+              TheModule->getTargetTriple(), TentativeDataLayoutStr))
+        TentativeDataLayoutStr = *LayoutOverride;
+    }
 
     // Now the layout string is finalized in TentativeDataLayoutStr. Parse it.
     Expected<DataLayout> MaybeDL = DataLayout::parse(TentativeDataLayoutStr);
@@ -4477,16 +4488,22 @@ Error BitcodeReader::parseModule(uint64_t ResumeBit,
     }
     Record.clear();
   }
+  this->ValueTypeCallback = std::nullopt;
   return Error::success();
 }
 
 Error BitcodeReader::parseBitcodeInto(Module *M, bool ShouldLazyLoadMetadata,
                                       bool IsImporting,
-                                      DataLayoutCallbackTy DataLayoutCallback) {
+                                      ParserCallbacks Callbacks) {
   TheModule = M;
-  MDLoader = MetadataLoader(Stream, *M, ValueList, IsImporting,
-                            [&](unsigned ID) { return getTypeByID(ID); });
-  return parseModule(0, ShouldLazyLoadMetadata, DataLayoutCallback);
+  MetadataLoaderCallbacks MDCallbacks;
+  MDCallbacks.GetTypeByID = [&](unsigned ID) { return getTypeByID(ID); };
+  MDCallbacks.GetContainedTypeID = [&](unsigned I, unsigned J) {
+    return getContainedTypeID(I, J);
+  };
+  MDCallbacks.MDType = Callbacks.MDType;
+  MDLoader = MetadataLoader(Stream, *M, ValueList, IsImporting, MDCallbacks);
+  return parseModule(0, ShouldLazyLoadMetadata, Callbacks);
 }
 
 Error BitcodeReader::typeCheckLoadStoreInst(Type *ValType, Type *PtrType) {
@@ -7919,7 +7936,7 @@ llvm::getBitcodeFileContents(MemoryBufferRef Buffer) {
 Expected<std::unique_ptr<Module>>
 BitcodeModule::getModuleImpl(LLVMContext &Context, bool MaterializeAll,
                              bool ShouldLazyLoadMetadata, bool IsImporting,
-                             DataLayoutCallbackTy DataLayoutCallback) {
+                             ParserCallbacks Callbacks) {
   BitstreamCursor Stream(Buffer);
 
   std::string ProducerIdentification;
@@ -7942,7 +7959,7 @@ BitcodeModule::getModuleImpl(LLVMContext &Context, bool MaterializeAll,
 
   // Delay parsing Metadata if ShouldLazyLoadMetadata is true.
   if (Error Err = R->parseBitcodeInto(M.get(), ShouldLazyLoadMetadata,
-                                      IsImporting, DataLayoutCallback))
+                                      IsImporting, Callbacks))
     return std::move(Err);
 
   if (MaterializeAll) {
@@ -7959,10 +7976,9 @@ BitcodeModule::getModuleImpl(LLVMContext &Context, bool MaterializeAll,
 
 Expected<std::unique_ptr<Module>>
 BitcodeModule::getLazyModule(LLVMContext &Context, bool ShouldLazyLoadMetadata,
-                             bool IsImporting,
-                             DataLayoutCallbackTy DataLayoutCallback) {
+                             bool IsImporting, ParserCallbacks Callbacks) {
   return getModuleImpl(Context, false, ShouldLazyLoadMetadata, IsImporting,
-                       DataLayoutCallback);
+                       Callbacks);
 }
 
 // Parse the specified bitcode buffer and merge the index into CombinedIndex.
@@ -8109,41 +8125,40 @@ static Expected<BitcodeModule> getSingleModule(MemoryBufferRef Buffer) {
 Expected<std::unique_ptr<Module>>
 llvm::getLazyBitcodeModule(MemoryBufferRef Buffer, LLVMContext &Context,
                            bool ShouldLazyLoadMetadata, bool IsImporting,
-                           DataLayoutCallbackTy DataLayoutCallback) {
+                           ParserCallbacks Callbacks) {
   Expected<BitcodeModule> BM = getSingleModule(Buffer);
   if (!BM)
     return BM.takeError();
 
   return BM->getLazyModule(Context, ShouldLazyLoadMetadata, IsImporting,
-                           DataLayoutCallback);
+                           Callbacks);
 }
 
 Expected<std::unique_ptr<Module>> llvm::getOwningLazyBitcodeModule(
     std::unique_ptr<MemoryBuffer> &&Buffer, LLVMContext &Context,
-    bool ShouldLazyLoadMetadata, bool IsImporting) {
+    bool ShouldLazyLoadMetadata, bool IsImporting, ParserCallbacks Callbacks) {
   auto MOrErr = getLazyBitcodeModule(*Buffer, Context, ShouldLazyLoadMetadata,
-                                     IsImporting);
+                                     IsImporting, Callbacks);
   if (MOrErr)
     (*MOrErr)->setOwnedMemoryBuffer(std::move(Buffer));
   return MOrErr;
 }
 
 Expected<std::unique_ptr<Module>>
-BitcodeModule::parseModule(LLVMContext &Context,
-                           DataLayoutCallbackTy DataLayoutCallback) {
-  return getModuleImpl(Context, true, false, false, DataLayoutCallback);
+BitcodeModule::parseModule(LLVMContext &Context, ParserCallbacks Callbacks) {
+  return getModuleImpl(Context, true, false, false, Callbacks);
   // TODO: Restore the use-lists to the in-memory state when the bitcode was
   // written.  We must defer until the Module has been fully materialized.
 }
 
 Expected<std::unique_ptr<Module>>
 llvm::parseBitcodeFile(MemoryBufferRef Buffer, LLVMContext &Context,
-                       DataLayoutCallbackTy DataLayoutCallback) {
+                       ParserCallbacks Callbacks) {
   Expected<BitcodeModule> BM = getSingleModule(Buffer);
   if (!BM)
     return BM.takeError();
 
-  return BM->parseModule(Context, DataLayoutCallback);
+  return BM->parseModule(Context, Callbacks);
 }
 
 Expected<std::string> llvm::getBitcodeTargetTriple(MemoryBufferRef Buffer) {

diff  --git a/llvm/lib/Bitcode/Reader/MetadataLoader.cpp b/llvm/lib/Bitcode/Reader/MetadataLoader.cpp
index fc452c2c69c57..4b5cfedaa99c1 100644
--- a/llvm/lib/Bitcode/Reader/MetadataLoader.cpp
+++ b/llvm/lib/Bitcode/Reader/MetadataLoader.cpp
@@ -406,7 +406,7 @@ class MetadataLoader::MetadataLoaderImpl {
   BitstreamCursor &Stream;
   LLVMContext &Context;
   Module &TheModule;
-  std::function<Type *(unsigned)> getTypeByID;
+  MetadataLoaderCallbacks Callbacks;
 
   /// Cursor associated with the lazy-loading of Metadata. This is the easy way
   /// to keep around the right "context" (Abbrev list) to be able to jump in
@@ -627,14 +627,15 @@ class MetadataLoader::MetadataLoaderImpl {
     upgradeCUVariables();
   }
 
+  void callMDTypeCallback(Metadata **Val, unsigned TypeID);
+
 public:
   MetadataLoaderImpl(BitstreamCursor &Stream, Module &TheModule,
                      BitcodeReaderValueList &ValueList,
-                     std::function<Type *(unsigned)> getTypeByID,
-                     bool IsImporting)
+                     MetadataLoaderCallbacks Callbacks, bool IsImporting)
       : MetadataList(TheModule.getContext(), Stream.SizeInBytes()),
         ValueList(ValueList), Stream(Stream), Context(TheModule.getContext()),
-        TheModule(TheModule), getTypeByID(std::move(getTypeByID)),
+        TheModule(TheModule), Callbacks(std::move(Callbacks)),
         IsImporting(IsImporting) {}
 
   Error parseMetadata(bool ModuleLevel);
@@ -952,6 +953,14 @@ Expected<bool> MetadataLoader::MetadataLoaderImpl::loadGlobalDeclAttachments() {
   }
 }
 
+void MetadataLoader::MetadataLoaderImpl::callMDTypeCallback(Metadata **Val,
+                                                            unsigned TypeID) {
+  if (Callbacks.MDType) {
+    (*Callbacks.MDType)(Val, TypeID, Callbacks.GetTypeByID,
+                        Callbacks.GetContainedTypeID);
+  }
+}
+
 /// Parse a METADATA_BLOCK. If ModuleLevel is true then we are parsing
 /// module level metadata.
 Error MetadataLoader::MetadataLoaderImpl::parseMetadata(bool ModuleLevel) {
@@ -1221,7 +1230,7 @@ Error MetadataLoader::MetadataLoaderImpl::parseOneMetadata(
     }
 
     unsigned TyID = Record[0];
-    Type *Ty = getTypeByID(TyID);
+    Type *Ty = Callbacks.GetTypeByID(TyID);
     if (Ty->isMetadataTy() || Ty->isVoidTy()) {
       dropRecord();
       break;
@@ -1245,7 +1254,7 @@ Error MetadataLoader::MetadataLoaderImpl::parseOneMetadata(
     SmallVector<Metadata *, 8> Elts;
     for (unsigned i = 0; i != Size; i += 2) {
       unsigned TyID = Record[i];
-      Type *Ty = getTypeByID(TyID);
+      Type *Ty = Callbacks.GetTypeByID(TyID);
       if (!Ty)
         return error("Invalid record");
       if (Ty->isMetadataTy())
@@ -1255,9 +1264,10 @@ Error MetadataLoader::MetadataLoaderImpl::parseOneMetadata(
                                             /*ConstExprInsertBB*/ nullptr);
         if (!V)
           return error("Invalid value reference from old metadata");
-        auto *MD = ValueAsMetadata::get(V);
+        Metadata *MD = ValueAsMetadata::get(V);
         assert(isa<ConstantAsMetadata>(MD) &&
                "Expected non-function-local metadata");
+        callMDTypeCallback(&MD, TyID);
         Elts.push_back(MD);
       } else
         Elts.push_back(nullptr);
@@ -1271,7 +1281,7 @@ Error MetadataLoader::MetadataLoaderImpl::parseOneMetadata(
       return error("Invalid record");
 
     unsigned TyID = Record[0];
-    Type *Ty = getTypeByID(TyID);
+    Type *Ty = Callbacks.GetTypeByID(TyID);
     if (Ty->isMetadataTy() || Ty->isVoidTy())
       return error("Invalid record");
 
@@ -1280,7 +1290,9 @@ Error MetadataLoader::MetadataLoaderImpl::parseOneMetadata(
     if (!V)
       return error("Invalid value reference from metadata");
 
-    MetadataList.assignValue(ValueAsMetadata::get(V), NextMetadataNo);
+    Metadata *MD = ValueAsMetadata::get(V);
+    callMDTypeCallback(&MD, TyID);
+    MetadataList.assignValue(MD, NextMetadataNo);
     NextMetadataNo++;
     break;
   }
@@ -2359,9 +2371,9 @@ MetadataLoader::~MetadataLoader() = default;
 MetadataLoader::MetadataLoader(BitstreamCursor &Stream, Module &TheModule,
                                BitcodeReaderValueList &ValueList,
                                bool IsImporting,
-                               std::function<Type *(unsigned)> getTypeByID)
+                               MetadataLoaderCallbacks Callbacks)
     : Pimpl(std::make_unique<MetadataLoaderImpl>(
-          Stream, TheModule, ValueList, std::move(getTypeByID), IsImporting)) {}
+          Stream, TheModule, ValueList, std::move(Callbacks), IsImporting)) {}
 
 Error MetadataLoader::parseMetadata(bool ModuleLevel) {
   return Pimpl->parseMetadata(ModuleLevel);

diff  --git a/llvm/lib/Bitcode/Reader/MetadataLoader.h b/llvm/lib/Bitcode/Reader/MetadataLoader.h
index 653f1402bead3..fbee7e49f8dff 100644
--- a/llvm/lib/Bitcode/Reader/MetadataLoader.h
+++ b/llvm/lib/Bitcode/Reader/MetadataLoader.h
@@ -29,6 +29,20 @@ class Module;
 class Type;
 template <typename T> class ArrayRef;
 
+typedef std::function<Type *(unsigned)> GetTypeByIDTy;
+
+typedef std::function<unsigned(unsigned, unsigned)> GetContainedTypeIDTy;
+
+typedef std::function<void(Metadata **, unsigned, GetTypeByIDTy,
+                           GetContainedTypeIDTy)>
+    MDTypeCallbackTy;
+
+struct MetadataLoaderCallbacks {
+  GetTypeByIDTy GetTypeByID;
+  GetContainedTypeIDTy GetContainedTypeID;
+  std::optional<MDTypeCallbackTy> MDType;
+};
+
 /// Helper class that handles loading Metadatas and keeping them available.
 class MetadataLoader {
   class MetadataLoaderImpl;
@@ -39,7 +53,7 @@ class MetadataLoader {
   ~MetadataLoader();
   MetadataLoader(BitstreamCursor &Stream, Module &TheModule,
                  BitcodeReaderValueList &ValueList, bool IsImporting,
-                 std::function<Type *(unsigned)> getTypeByID);
+                 MetadataLoaderCallbacks Callbacks);
   MetadataLoader &operator=(MetadataLoader &&);
   MetadataLoader(MetadataLoader &&);
 

diff  --git a/llvm/lib/IRReader/IRReader.cpp b/llvm/lib/IRReader/IRReader.cpp
index 7765c3fbf2df0..7885c36a79876 100644
--- a/llvm/lib/IRReader/IRReader.cpp
+++ b/llvm/lib/IRReader/IRReader.cpp
@@ -68,14 +68,14 @@ std::unique_ptr<Module> llvm::getLazyIRFileModule(StringRef Filename,
 
 std::unique_ptr<Module> llvm::parseIR(MemoryBufferRef Buffer, SMDiagnostic &Err,
                                       LLVMContext &Context,
-                                      DataLayoutCallbackTy DataLayoutCallback) {
+                                      ParserCallbacks Callbacks) {
   NamedRegionTimer T(TimeIRParsingName, TimeIRParsingDescription,
                      TimeIRParsingGroupName, TimeIRParsingGroupDescription,
                      TimePassesIsEnabled);
   if (isBitcode((const unsigned char *)Buffer.getBufferStart(),
                 (const unsigned char *)Buffer.getBufferEnd())) {
     Expected<std::unique_ptr<Module>> ModuleOrErr =
-        parseBitcodeFile(Buffer, Context, DataLayoutCallback);
+        parseBitcodeFile(Buffer, Context, Callbacks);
     if (Error E = ModuleOrErr.takeError()) {
       handleAllErrors(std::move(E), [&](ErrorInfoBase &EIB) {
         Err = SMDiagnostic(Buffer.getBufferIdentifier(), SourceMgr::DK_Error,
@@ -86,12 +86,14 @@ std::unique_ptr<Module> llvm::parseIR(MemoryBufferRef Buffer, SMDiagnostic &Err,
     return std::move(ModuleOrErr.get());
   }
 
-  return parseAssembly(Buffer, Err, Context, nullptr, DataLayoutCallback);
+  return parseAssembly(Buffer, Err, Context, nullptr,
+                       Callbacks.DataLayout.value_or(
+                           [](StringRef, StringRef) { return std::nullopt; }));
 }
 
-std::unique_ptr<Module>
-llvm::parseIRFile(StringRef Filename, SMDiagnostic &Err, LLVMContext &Context,
-                  DataLayoutCallbackTy DataLayoutCallback) {
+std::unique_ptr<Module> llvm::parseIRFile(StringRef Filename, SMDiagnostic &Err,
+                                          LLVMContext &Context,
+                                          ParserCallbacks Callbacks) {
   ErrorOr<std::unique_ptr<MemoryBuffer>> FileOrErr =
       MemoryBuffer::getFileOrSTDIN(Filename, /*IsText=*/true);
   if (std::error_code EC = FileOrErr.getError()) {
@@ -100,8 +102,7 @@ llvm::parseIRFile(StringRef Filename, SMDiagnostic &Err, LLVMContext &Context,
     return nullptr;
   }
 
-  return parseIR(FileOrErr.get()->getMemBufferRef(), Err, Context,
-                 DataLayoutCallback);
+  return parseIR(FileOrErr.get()->getMemBufferRef(), Err, Context, Callbacks);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/tools/llc/llc.cpp b/llvm/tools/llc/llc.cpp
index d59b128b876d1..e6b65ac446618 100644
--- a/llvm/tools/llc/llc.cpp
+++ b/llvm/tools/llc/llc.cpp
@@ -568,7 +568,8 @@ static int compileModule(char **argv, LLVMContext &Context) {
       if (MIR)
         M = MIR->parseIRModule(SetDataLayout);
     } else {
-      M = parseIRFile(InputFilename, Err, Context, SetDataLayout);
+      M = parseIRFile(InputFilename, Err, Context,
+                      ParserCallbacks(SetDataLayout));
     }
     if (!M) {
       Err.print(argv[0], WithColor::error(errs(), argv[0]));

diff  --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp
index 392f0603ab968..40632b43e73bf 100644
--- a/llvm/tools/opt/opt.cpp
+++ b/llvm/tools/opt/opt.cpp
@@ -545,7 +545,8 @@ int main(int argc, char **argv) {
             InputFilename, Err, Context, nullptr, SetDataLayout)
             .Mod;
   else
-    M = parseIRFile(InputFilename, Err, Context, SetDataLayout);
+    M = parseIRFile(InputFilename, Err, Context,
+                    ParserCallbacks(SetDataLayout));
 
   if (!M) {
     Err.print(argv[0], errs());

diff  --git a/llvm/unittests/Bitcode/BitReaderTest.cpp b/llvm/unittests/Bitcode/BitReaderTest.cpp
index c4f9b672ac911..e226df3c8dcab 100644
--- a/llvm/unittests/Bitcode/BitReaderTest.cpp
+++ b/llvm/unittests/Bitcode/BitReaderTest.cpp
@@ -6,11 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/Bitcode/BitcodeReader.h"
 #include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
@@ -255,4 +257,206 @@ TEST(BitReaderTest, MaterializeFunctionsForBlockAddrInFunctionAfter) {
   EXPECT_FALSE(verifyModule(*M, &dbgs()));
 }
 
+// Helper function to convert type metadata to a string for testing.
+// Unexpected input yields a marker string, so a mismatch fails EXPECT_EQ
+// instead of crashing on null or hitting llvm_unreachable in release builds.
+static std::string mdToString(Metadata *MD) {
+  if (!MD)
+    return "<null>";
+  if (auto *VMD = dyn_cast<ValueAsMetadata>(MD))
+    if (VMD->getType()->isPointerTy())
+      return "ptr";
+  std::string S;
+  if (auto *TMD = dyn_cast<MDTuple>(MD)) {
+    S += "!{";
+    for (unsigned I = 0, E = TMD->getNumOperands(); I != E; ++I) {
+      if (I != 0)
+        S += ", ";
+      S += mdToString(TMD->getOperand(I).get());
+    }
+    S += "}";
+  } else if (auto *SMD = dyn_cast<MDString>(MD)) {
+    S += "!'";
+    S += SMD->getString();
+    S += "'";
+  } else if (auto *I = mdconst::dyn_extract<ConstantInt>(MD)) {
+    S += std::to_string(I->getZExtValue());
+  } else if (auto *P = mdconst::dyn_extract<PoisonValue>(MD)) {
+    auto *Ty = P->getType();
+    if (Ty->isIntegerTy()) {
+      S += "i";
+      S += std::to_string(Ty->getIntegerBitWidth());
+    } else if (Ty->isStructTy()) {
+      S += "%";
+      S += Ty->getStructName();
+    } else {
+      S += "<unhandled poison metadata>";
+    }
+  } else {
+    S += "<unhandled metadata>";
+  }
+  return S;
+}
+
+// Recursively look into a (pointer) type and build metadata describing it.
+// For primitive types it's a poison value of the type, for a pointer it's a
+// metadata tuple with the addrspace and the referenced type. For a function,
+// it's a tuple where the first element is the string "function", the second
+// element is the return type or the string "void" and the following elements
+// are the argument types.
+static Metadata *
+getTypeMetadataEntry(unsigned TypeID, LLVMContext &Context,
+                     const GetTypeByIDTy &GetTypeByID,
+                     const GetContainedTypeIDTy &GetContainedTypeID) {
+  // Take the std::function callbacks by reference: no copy per recursion.
+  auto Recurse = [&](unsigned Idx) {
+    return getTypeMetadataEntry(GetContainedTypeID(TypeID, Idx), Context,
+                                GetTypeByID, GetContainedTypeID);
+  };
+  Type *Ty = GetTypeByID(TypeID);
+  if (auto *FTy = dyn_cast<FunctionType>(Ty)) {
+    // Save the function signature as metadata
+    SmallVector<Metadata *> SignatureMD;
+    SignatureMD.push_back(MDString::get(Context, "function"));
+    // Return type is contained type 0
+    if (FTy->getReturnType()->isVoidTy())
+      SignatureMD.push_back(MDString::get(Context, "void"));
+    else
+      SignatureMD.push_back(Recurse(0));
+    // Argument I is contained type I + 1
+    for (unsigned I = 0; I != FTy->getNumParams(); ++I)
+      SignatureMD.push_back(Recurse(I + 1));
+    return MDTuple::get(Context, SignatureMD);
+  }
+
+  if (!Ty->isPointerTy())
+    return ConstantAsMetadata::get(PoisonValue::get(Ty));
+
+  // Return !{<addrspace>, <inner>} for pointer
+  SmallVector<Metadata *, 2> MD;
+  MD.push_back(ConstantAsMetadata::get(ConstantInt::get(
+      Type::getInt32Ty(Context), Ty->getPointerAddressSpace())));
+  MD.push_back(Recurse(0));
+  return MDTuple::get(Context, MD);
+}
+
+// Test that when reading bitcode with typed pointers and upgrading them to
+// opaque pointers, the type information of function signatures can be extracted
+// and stored in metadata.
+TEST(BitReaderTest, AccessFunctionTypeInfo) {
+  SmallString<1024> Mem;
+  LLVMContext WriteContext;
+  writeModuleToBuffer(
+      parseAssembly(
+          WriteContext,
+          "define void @func() {\n"
+          "  unreachable\n"
+          "}\n"
+          "declare i32 @func_header()\n"
+          "declare i8* @ret_ptr()\n"
+          "declare i8* @ret_and_arg_ptr(i32 addrspace(8)*)\n"
+          "declare i8 addrspace(1)* @double_ptr(i32* addrspace(2)*, i32***)\n"),
+      Mem);
+
+  LLVMContext Context;
+  Context.setOpaquePointers(true);
+
+  ParserCallbacks Callbacks;
+  // Supply a callback that stores the signature of a function into metadata,
+  // so that the types behind pointers can be accessed.
+  // Each function gets a !types metadata: a tuple of the string "function",
+  // the return type (or the string "void") and one entry per argument. For
+  // primitive types an entry is a poison value of the type, for a pointer a
+  // tuple with the addrspace and the referenced type.
+  Callbacks.ValueType = [](Value *V, unsigned TypeID,
+                           GetTypeByIDTy GetTypeByID,
+                           GetContainedTypeIDTy GetContainedTypeID) {
+    if (auto *F = dyn_cast<Function>(V)) {
+      auto *MD = getTypeMetadataEntry(TypeID, F->getContext(), GetTypeByID,
+                                      GetContainedTypeID);
+      F->setMetadata("types", cast<MDNode>(MD));
+    }
+  };
+
+  Expected<std::unique_ptr<Module>> ModuleOrErr =
+      parseBitcodeFile(MemoryBufferRef(Mem.str(), "test"), Context, Callbacks);
+
+  if (!ModuleOrErr)
+    report_fatal_error(ModuleOrErr.takeError());
+  std::unique_ptr<Module> M = std::move(ModuleOrErr.get());
+
+  EXPECT_EQ(mdToString(M->getFunction("func")->getMetadata("types")),
+            "!{!'function', !'void'}");
+  EXPECT_EQ(mdToString(M->getFunction("func_header")->getMetadata("types")),
+            "!{!'function', i32}");
+  EXPECT_EQ(mdToString(M->getFunction("ret_ptr")->getMetadata("types")),
+            "!{!'function', !{0, i8}}");
+  EXPECT_EQ(mdToString(M->getFunction("ret_and_arg_ptr")->getMetadata("types")),
+            "!{!'function', !{0, i8}, !{8, i32}}");
+  EXPECT_EQ(mdToString(M->getFunction("double_ptr")->getMetadata("types")),
+            "!{!'function', !{1, i8}, !{2, !{0, i32}}, !{0, !{0, !{0, i32}}}}");
+}
+
+// Test that when reading bitcode with typed pointers and upgrading them to
+// opaque pointers, the type information of pointers in metadata can be
+// extracted and stored in metadata.
+TEST(BitReaderTest, AccessMetadataTypeInfo) {
+  SmallString<1024> Mem;
+  LLVMContext WriteContext;
+  writeModuleToBuffer(
+      parseAssembly(WriteContext,
+                    "%dx.types.f32 = type { float }\n"
+                    "declare void @main()\n"
+                    "!md = !{!0}\n"
+                    "!md2 = !{!1}\n"
+                    "!0 = !{i32 2, %dx.types.f32 addrspace(1)* undef, void ()* "
+                    "@main, void() addrspace(3)* null}\n"
+                    "!1 = !{i8*(i32* addrspace(2)*) addrspace(4)* undef, "
+                    "i32*** undef}\n"),
+      Mem);
+
+  LLVMContext Context;
+  Context.setOpaquePointers(true);
+
+  ParserCallbacks Callbacks;
+  // Supply a callback that stores types from metadata,
+  // so that the types behind pointers can be accessed.
+  // Non-pointer entries are ignored. Values with a pointer type are
+  // replaced by a metadata tuple with {original value, type md}. We cannot
+  // save the metadata outside because after conversion to opaque pointers,
+  // entries are not distinguishable anymore (e.g. i32* and i8* are both
+  // upgraded to ptr).
+  Callbacks.MDType = [](Metadata **Val, unsigned TypeID,
+                        GetTypeByIDTy GetTypeByID,
+                        GetContainedTypeIDTy GetContainedTypeID) {
+    auto *OrigVal = cast<ValueAsMetadata>(*Val);
+    if (OrigVal->getType()->isPointerTy()) {
+      // Ignore function references, their signature can be saved like
+      // in the test above
+      if (!isa<Function>(OrigVal->getValue())) {
+        SmallVector<Metadata *> Tuple;
+        Tuple.push_back(OrigVal);
+        Tuple.push_back(getTypeMetadataEntry(GetContainedTypeID(TypeID, 0),
+                                             OrigVal->getContext(), GetTypeByID,
+                                             GetContainedTypeID));
+        *Val = MDTuple::get(OrigVal->getContext(), Tuple);
+      }
+    }
+  };
+
+  Expected<std::unique_ptr<Module>> ModuleOrErr =
+      parseBitcodeFile(MemoryBufferRef(Mem.str(), "test"), Context, Callbacks);
+
+  if (!ModuleOrErr)
+    report_fatal_error(ModuleOrErr.takeError());
+  std::unique_ptr<Module> M = std::move(ModuleOrErr.get());
+
+  EXPECT_EQ(
+      mdToString(M->getNamedMetadata("md")->getOperand(0)),
+      "!{2, !{ptr, %dx.types.f32}, ptr, !{ptr, !{!'function', !'void'}}}");
+  EXPECT_EQ(mdToString(M->getNamedMetadata("md2")->getOperand(0)),
+            "!{!{ptr, !{!'function', !{0, i8}, !{2, !{0, i32}}}}, !{ptr, !{0, "
+            "!{0, i32}}}}");
+}
+
 } // end namespace


        


More information about the llvm-commits mailing list