[llvm-branch-commits] [clang] [Serialization] Introduce OnDiskHashTable for specializations (PR #83233)

Chuanqi Xu via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Nov 8 00:53:05 PST 2024


https://github.com/ChuanqiXu9 updated https://github.com/llvm/llvm-project/pull/83233

>From f565dd3f156bbdf608be6643208c40f02b4f0e83 Mon Sep 17 00:00:00 2001
From: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
Date: Wed, 28 Feb 2024 11:41:53 +0800
Subject: [PATCH] [Serialization] Introduce OnDiskHashTable for specializations

Following up for https://github.com/llvm/llvm-project/pull/83108

This follows the suggestion literally from
https://github.com/llvm/llvm-project/pull/76774#issuecomment-1951172457

which introduces OnDiskHashTable for specializations based on
D41416.

Note that I didn't polish this patch, in order to keep the diff from D41416
small and make it easier to review. I'll send the polishing patch later, so
that we can focus on what we're doing in this patch and focus on the style in
the next patch.
---
 clang/include/clang/AST/ExternalASTSource.h   |  11 +
 .../clang/Sema/MultiplexExternalSemaSource.h  |   6 +
 .../include/clang/Serialization/ASTBitCodes.h |   6 +
 clang/include/clang/Serialization/ASTReader.h |  34 ++-
 clang/include/clang/Serialization/ASTWriter.h |  15 +
 clang/lib/AST/DeclTemplate.cpp                |  17 ++
 clang/lib/AST/ExternalASTSource.cpp           |   5 +
 .../lib/Sema/MultiplexExternalSemaSource.cpp  |  12 +
 clang/lib/Serialization/ASTReader.cpp         | 145 +++++++++-
 clang/lib/Serialization/ASTReaderDecl.cpp     |  27 ++
 clang/lib/Serialization/ASTReaderInternals.h  | 124 +++++++++
 clang/lib/Serialization/ASTWriter.cpp         | 174 +++++++++++-
 clang/lib/Serialization/ASTWriterDecl.cpp     |  32 ++-
 clang/unittests/Serialization/CMakeLists.txt  |   1 +
 .../Serialization/LoadSpecLazilyTest.cpp      | 260 ++++++++++++++++++
 15 files changed, 856 insertions(+), 13 deletions(-)
 create mode 100644 clang/unittests/Serialization/LoadSpecLazilyTest.cpp

diff --git a/clang/include/clang/AST/ExternalASTSource.h b/clang/include/clang/AST/ExternalASTSource.h
index 582ed7c65f58ca..5f4f9a9a8d681e 100644
--- a/clang/include/clang/AST/ExternalASTSource.h
+++ b/clang/include/clang/AST/ExternalASTSource.h
@@ -152,6 +152,17 @@ class ExternalASTSource : public RefCountedBase<ExternalASTSource> {
   virtual bool
   FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name);
 
+  /// Load all the external specializations for the Decl \param D if \param
+  /// OnlyPartial is false. Otherwise, load all the external **partial**
+  /// specializations for the \param D.
+  virtual void LoadExternalSpecializations(const Decl *D, bool OnlyPartial);
+
+  /// Load all the specializations for the Decl \param D with the same template
+  /// args specified by \param TemplateArgs.
+  virtual void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs);
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   ///
diff --git a/clang/include/clang/Sema/MultiplexExternalSemaSource.h b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
index 3d1906d8699265..78bbbaf2d7b5c6 100644
--- a/clang/include/clang/Sema/MultiplexExternalSemaSource.h
+++ b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
@@ -97,6 +97,12 @@ class MultiplexExternalSemaSource : public ExternalSemaSource {
   bool FindExternalVisibleDeclsByName(const DeclContext *DC,
                                       DeclarationName Name) override;
 
+  void LoadExternalSpecializations(const Decl *D, bool OnlyPartial) override;
+
+  void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   void completeVisibleDeclsMap(const DeclContext *DC) override;
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index e397dff097652b..405954ba922b5e 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -734,6 +734,9 @@ enum ASTRecordTypes {
   /// Record code for Sema's vector of functions/blocks with effects to
   /// be verified.
   DECLS_WITH_EFFECTS_TO_VERIFY = 72,
+
+  /// Record code for updated specialization
+  UPDATE_SPECIALIZATION = 73,
 };
 
 /// Record types used within a source manager block.
@@ -1500,6 +1503,9 @@ enum DeclCode {
   /// A HLSLBufferDecl record.
   DECL_HLSL_BUFFER,
 
+  /// A decl specializations record.
+  DECL_SPECIALIZATIONS,
+
   /// An ImplicitConceptSpecializationDecl record.
   DECL_IMPLICIT_CONCEPT_SPECIALIZATION,
 
diff --git a/clang/include/clang/Serialization/ASTReader.h b/clang/include/clang/Serialization/ASTReader.h
index b476a40ebd2c8c..84bb25754f526f 100644
--- a/clang/include/clang/Serialization/ASTReader.h
+++ b/clang/include/clang/Serialization/ASTReader.h
@@ -353,6 +353,9 @@ class ASTIdentifierLookupTrait;
 /// The on-disk hash table(s) used for DeclContext name lookup.
 struct DeclContextLookupTable;
 
+/// The on-disk hash table(s) used for specialization decls.
+struct LazySpecializationInfoLookupTable;
+
 } // namespace reader
 
 } // namespace serialization
@@ -631,20 +634,29 @@ class ASTReader
   llvm::DenseMap<const DeclContext *,
                  serialization::reader::DeclContextLookupTable> Lookups;
 
+  /// Map from decls to specialized decls.
+  llvm::DenseMap<const Decl *,
+                 serialization::reader::LazySpecializationInfoLookupTable>
+      SpecializationsLookups;
+
   // Updates for visible decls can occur for other contexts than just the
   // TU, and when we read those update records, the actual context may not
   // be available yet, so have this pending map using the ID as a key. It
-  // will be realized when the context is actually loaded.
-  struct PendingVisibleUpdate {
+  // will be realized when the data is actually loaded.
+  struct UpdateData {
     ModuleFile *Mod;
     const unsigned char *Data;
   };
-  using DeclContextVisibleUpdates = SmallVector<PendingVisibleUpdate, 1>;
+  using DeclContextVisibleUpdates = SmallVector<UpdateData, 1>;
 
   /// Updates to the visible declarations of declaration contexts that
   /// haven't been loaded yet.
   llvm::DenseMap<GlobalDeclID, DeclContextVisibleUpdates> PendingVisibleUpdates;
 
+  using SpecializationsUpdate = SmallVector<UpdateData, 1>;
+  llvm::DenseMap<GlobalDeclID, SpecializationsUpdate>
+      PendingSpecializationsUpdates;
+
   /// The set of C++ or Objective-C classes that have forward
   /// declarations that have not yet been linked to their definitions.
   llvm::SmallPtrSet<Decl *, 4> PendingDefinitions;
@@ -677,6 +689,11 @@ class ASTReader
                                      llvm::BitstreamCursor &Cursor,
                                      uint64_t Offset, GlobalDeclID ID);
 
+  bool ReadSpecializations(ModuleFile &M, llvm::BitstreamCursor &Cursor,
+                           uint64_t Offset, Decl *D);
+  void AddSpecializations(const Decl *D, const unsigned char *Data,
+                          ModuleFile &M);
+
   /// A vector containing identifiers that have already been
   /// loaded.
   ///
@@ -1379,6 +1396,11 @@ class ASTReader
   const serialization::reader::DeclContextLookupTable *
   getLoadedLookupTables(DeclContext *Primary) const;
 
+  /// Get the loaded specializations lookup tables for \p D,
+  /// if any.
+  serialization::reader::LazySpecializationInfoLookupTable *
+  getLoadedSpecializationsLookupTables(const Decl *D);
+
 private:
   struct ImportedModule {
     ModuleFile *Mod;
@@ -2036,6 +2058,12 @@ class ASTReader
                                       unsigned BlockID,
                                       uint64_t *StartOfBlockOffset = nullptr);
 
+  void LoadExternalSpecializations(const Decl *D, bool OnlyPartial) override;
+
+  void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Finds all the visible declarations with a given name.
   /// The current implementation of this method just loads the entire
   /// lookup table as unmaterialized references.
diff --git a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h
index 198dd01b8d07a0..d2b74d00261e24 100644
--- a/clang/include/clang/Serialization/ASTWriter.h
+++ b/clang/include/clang/Serialization/ASTWriter.h
@@ -426,6 +426,12 @@ class ASTWriter : public ASTDeserializationListener,
   /// Only meaningful for reduced BMI.
   DeclUpdateMap DeclUpdatesFromGMF;
 
+  /// Mapping from template decls to their new specializations in the
+  /// current TU.
+  using SpecializationUpdateMap =
+      llvm::MapVector<const NamedDecl *, SmallVector<const Decl *>>;
+  SpecializationUpdateMap SpecializationsUpdates;
+
   using FirstLatestDeclMap = llvm::DenseMap<Decl *, Decl *>;
 
   /// Map of first declarations from a chained PCH that point to the
@@ -577,6 +583,12 @@ class ASTWriter : public ASTDeserializationListener,
 
   bool isLookupResultExternal(StoredDeclsList &Result, DeclContext *DC);
 
+  void GenerateSpecializationInfoLookupTable(
+      const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations,
+      llvm::SmallVectorImpl<char> &LookupTable);
+  uint64_t WriteSpecializationInfoLookupTable(
+      const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations);
+
   void GenerateNameLookupTable(const DeclContext *DC,
                                llvm::SmallVectorImpl<char> &LookupTable);
   uint64_t WriteDeclContextLexicalBlock(ASTContext &Context,
@@ -592,6 +604,7 @@ class ASTWriter : public ASTDeserializationListener,
   void WriteDeclAndTypes(ASTContext &Context);
   void PrepareWritingSpecialDecls(Sema &SemaRef);
   void WriteSpecialDeclRecords(Sema &SemaRef);
+  void WriteSpecializationsUpdates();
   void WriteDeclUpdatesBlocks(RecordDataImpl &OffsetsRecord);
   void WriteDeclContextVisibleUpdate(const DeclContext *DC);
   void WriteFPPragmaOptions(const FPOptionsOverride &Opts);
@@ -619,6 +632,8 @@ class ASTWriter : public ASTDeserializationListener,
   unsigned DeclEnumAbbrev = 0;
   unsigned DeclObjCIvarAbbrev = 0;
   unsigned DeclCXXMethodAbbrev = 0;
+  unsigned DeclSpecializationsAbbrev = 0;
+
   unsigned DeclDependentNonTemplateCXXMethodAbbrev = 0;
   unsigned DeclTemplateCXXMethodAbbrev = 0;
   unsigned DeclMemberSpecializedCXXMethodAbbrev = 0;
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index b56dd39b81b835..223c0f29128c44 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -344,6 +344,14 @@ RedeclarableTemplateDecl::CommonBase *RedeclarableTemplateDecl::getCommonPtr() c
 
 void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
     bool OnlyPartial /*=false*/) const {
+  auto *ExternalSource = getASTContext().getExternalSource();
+  if (!ExternalSource)
+    return;
+
+  ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(),
+                                              OnlyPartial);
+  return;
+
   // Grab the most recent declaration to ensure we've loaded any lazy
   // redeclarations of this template.
   CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
@@ -363,6 +371,8 @@ void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
 
 Decl *RedeclarableTemplateDecl::loadLazySpecializationImpl(
     LazySpecializationInfo &LazySpecInfo) const {
+  llvm_unreachable("We don't use LazySpecializationInfo any more");
+
   GlobalDeclID ID = LazySpecInfo.DeclID;
   assert(ID.isValid() && "Loading already loaded specialization!");
   // Note that we loaded the specialization.
@@ -373,6 +383,13 @@ Decl *RedeclarableTemplateDecl::loadLazySpecializationImpl(
 
 void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
     ArrayRef<TemplateArgument> Args, TemplateParameterList *TPL) const {
+  auto *ExternalSource = getASTContext().getExternalSource();
+  if (!ExternalSource)
+    return;
+
+  ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(), Args);
+  return;
+
   CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
   if (auto *Specs = CommonBasePtr->LazySpecializations) {
     unsigned Hash = TemplateArgumentList::ComputeODRHash(Args);
diff --git a/clang/lib/AST/ExternalASTSource.cpp b/clang/lib/AST/ExternalASTSource.cpp
index a5b6f80bde694c..122014bfeb2321 100644
--- a/clang/lib/AST/ExternalASTSource.cpp
+++ b/clang/lib/AST/ExternalASTSource.cpp
@@ -98,6 +98,11 @@ ExternalASTSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
   return false;
 }
 
+void ExternalASTSource::LoadExternalSpecializations(const Decl *D, bool) {}
+
+void ExternalASTSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument>) {}
+
 void ExternalASTSource::completeVisibleDeclsMap(const DeclContext *DC) {}
 
 void ExternalASTSource::FindExternalLexicalDecls(
diff --git a/clang/lib/Sema/MultiplexExternalSemaSource.cpp b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
index cd44483b5cbe04..f39463712ce81f 100644
--- a/clang/lib/Sema/MultiplexExternalSemaSource.cpp
+++ b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
@@ -115,6 +115,18 @@ FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name) {
   return AnyDeclsFound;
 }
 
+void MultiplexExternalSemaSource::LoadExternalSpecializations(
+    const Decl *D, bool OnlyPartial) {
+  for (size_t i = 0; i < Sources.size(); ++i)
+    Sources[i]->LoadExternalSpecializations(D, OnlyPartial);
+}
+
+void MultiplexExternalSemaSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  for (size_t i = 0; i < Sources.size(); ++i)
+    Sources[i]->LoadExternalSpecializations(D, TemplateArgs);
+}
+
 void MultiplexExternalSemaSource::completeVisibleDeclsMap(const DeclContext *DC){
   for(size_t i = 0; i < Sources.size(); ++i)
     Sources[i]->completeVisibleDeclsMap(DC);
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 1487fd3d0aa7f6..b49651b23e0366 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -1293,6 +1293,43 @@ void ASTDeclContextNameLookupTrait::ReadDataInto(internal_key_type,
   }
 }
 
+ModuleFile *
+LazySpecializationInfoLookupTrait::ReadFileRef(const unsigned char *&d) {
+  using namespace llvm::support;
+
+  uint32_t ModuleFileID =
+      endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+  return Reader.getLocalModuleFile(F, ModuleFileID);
+}
+
+LazySpecializationInfoLookupTrait::internal_key_type
+LazySpecializationInfoLookupTrait::ReadKey(const unsigned char *d, unsigned) {
+  using namespace llvm::support;
+  return endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+}
+
+std::pair<unsigned, unsigned>
+LazySpecializationInfoLookupTrait::ReadKeyDataLength(const unsigned char *&d) {
+  return readULEBKeyDataLength(d);
+}
+
+void LazySpecializationInfoLookupTrait::ReadDataInto(internal_key_type,
+                                                     const unsigned char *d,
+                                                     unsigned DataLen,
+                                                     data_type_builder &Val) {
+  using namespace llvm::support;
+
+  for (unsigned NumDecls =
+           DataLen / serialization::reader::LazySpecializationInfo::Length;
+       NumDecls; --NumDecls) {
+    LocalDeclID LocalID =
+        LocalDeclID::get(Reader, F, endian::readNext<DeclID, llvm::endianness::little, unaligned>(d));
+    const bool IsPartial =
+        endian::readNext<bool, llvm::endianness::little, unaligned>(d);
+    Val.insert({Reader.getGlobalDeclID(F, LocalID), IsPartial});
+  }
+}
+
 bool ASTReader::ReadLexicalDeclContextStorage(ModuleFile &M,
                                               BitstreamCursor &Cursor,
                                               uint64_t Offset,
@@ -1377,7 +1414,49 @@ bool ASTReader::ReadVisibleDeclContextStorage(ModuleFile &M,
   // We can't safely determine the primary context yet, so delay attaching the
   // lookup table until we're done with recursive deserialization.
   auto *Data = (const unsigned char*)Blob.data();
-  PendingVisibleUpdates[ID].push_back(PendingVisibleUpdate{&M, Data});
+  PendingVisibleUpdates[ID].push_back(UpdateData{&M, Data});
+  return false;
+}
+
+void ASTReader::AddSpecializations(const Decl *D, const unsigned char *Data,
+                                   ModuleFile &M) {
+  D = D->getCanonicalDecl();
+  SpecializationsLookups[D].Table.add(
+      &M, Data, reader::LazySpecializationInfoLookupTrait(*this, M));
+}
+
+bool ASTReader::ReadSpecializations(ModuleFile &M, BitstreamCursor &Cursor,
+                                    uint64_t Offset, Decl *D) {
+  assert(Offset != 0);
+
+  SavedStreamPosition SavedPosition(Cursor);
+  if (llvm::Error Err = Cursor.JumpToBit(Offset)) {
+    Error(std::move(Err));
+    return true;
+  }
+
+  RecordData Record;
+  StringRef Blob;
+  Expected<unsigned> MaybeCode = Cursor.ReadCode();
+  if (!MaybeCode) {
+    Error(MaybeCode.takeError());
+    return true;
+  }
+  unsigned Code = MaybeCode.get();
+
+  Expected<unsigned> MaybeRecCode = Cursor.readRecord(Code, Record, &Blob);
+  if (!MaybeRecCode) {
+    Error(MaybeRecCode.takeError());
+    return true;
+  }
+  unsigned RecCode = MaybeRecCode.get();
+  if (RecCode != DECL_SPECIALIZATIONS) {
+    Error("Expected decl specs block");
+    return true;
+  }
+
+  auto *Data = (const unsigned char *)Blob.data();
+  AddSpecializations(D, Data, M);
   return false;
 }
 
@@ -3463,7 +3542,20 @@ llvm::Error ASTReader::ReadASTBlock(ModuleFile &F,
       unsigned Idx = 0;
       GlobalDeclID ID = ReadDeclID(F, Record, Idx);
       auto *Data = (const unsigned char*)Blob.data();
-      PendingVisibleUpdates[ID].push_back(PendingVisibleUpdate{&F, Data});
+      PendingVisibleUpdates[ID].push_back(UpdateData{&F, Data});
+      // If we've already loaded the decl, perform the updates when we finish
+      // loading this block.
+      if (Decl *D = GetExistingDecl(ID))
+        PendingUpdateRecords.push_back(
+            PendingUpdateRecord(ID, D, /*JustLoaded=*/false));
+      break;
+    }
+
+    case UPDATE_SPECIALIZATION: {
+      unsigned Idx = 0;
+      GlobalDeclID ID = ReadDeclID(F, Record, Idx);
+      auto *Data = (const unsigned char *)Blob.data();
+      PendingSpecializationsUpdates[ID].push_back(UpdateData{&F, Data});
       // If we've already loaded the decl, perform the updates when we finish
       // loading this block.
       if (Decl *D = GetExistingDecl(ID))
@@ -8057,6 +8149,48 @@ Stmt *ASTReader::GetExternalDeclStmt(uint64_t Offset) {
   return ReadStmtFromStream(*Loc.F);
 }
 
+void ASTReader::LoadExternalSpecializations(const Decl *D, bool OnlyPartial) {
+  assert(D);
+
+  auto It = SpecializationsLookups.find(D);
+  if (It == SpecializationsLookups.end())
+    return;
+
+  // GetDecl may invalidate the iterator into SpecializationsLookups, so we
+  // collect the lookup results ahead of time.
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 8> Infos =
+      It->second.Table.findAll();
+
+  // Since we've loaded all the specializations, we can erase it from
+  // the lookup table.
+  if (!OnlyPartial)
+    SpecializationsLookups.erase(It);
+
+  Deserializing LookupResults(this);
+  for (auto &Info : Infos)
+    if (!OnlyPartial || Info.IsPartial)
+      GetDecl(Info.ID);
+}
+
+void ASTReader::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  assert(D);
+
+  auto It = SpecializationsLookups.find(D);
+  if (It == SpecializationsLookups.end())
+    return;
+
+  Deserializing LookupResults(this);
+  auto HashValue = TemplateArgumentList::ComputeODRHash(TemplateArgs);
+
+  // GetDecl may invalidate the iterator into SpecializationsLookups.
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 8> Infos =
+      It->second.Table.find(HashValue);
+
+  for (auto &Info : Infos)
+    GetDecl(Info.ID);
+}
+
 void ASTReader::FindExternalLexicalDecls(
     const DeclContext *DC, llvm::function_ref<bool(Decl::Kind)> IsKindWeWant,
     SmallVectorImpl<Decl *> &Decls) {
@@ -8236,6 +8370,13 @@ ASTReader::getLoadedLookupTables(DeclContext *Primary) const {
   return I == Lookups.end() ? nullptr : &I->second;
 }
 
+serialization::reader::LazySpecializationInfoLookupTable *
+ASTReader::getLoadedSpecializationsLookupTables(const Decl *D) {
+  assert(D->isCanonicalDecl());
+  auto I = SpecializationsLookups.find(D);
+  return I == SpecializationsLookups.end() ? nullptr : &I->second;
+}
+
 /// Under non-PCH compilation the consumer receives the objc methods
 /// before receiving the implementation, and codegen depends on this.
 /// We simulate this by deserializing and passing to consumer the methods of the
diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp
index 572d838ba9f72c..b7c8639953d269 100644
--- a/clang/lib/Serialization/ASTReaderDecl.cpp
+++ b/clang/lib/Serialization/ASTReaderDecl.cpp
@@ -342,6 +342,9 @@ class ASTDeclReader : public DeclVisitor<ASTDeclReader, void> {
   static void markIncompleteDeclChainImpl(Redeclarable<DeclT> *D);
   static void markIncompleteDeclChainImpl(...);
 
+  void ReadSpecializations(ModuleFile &M, Decl *D,
+                           llvm::BitstreamCursor &DeclsCursor);
+
   void ReadFunctionDefinition(FunctionDecl *FD);
   void Visit(Decl *D);
 
@@ -2421,6 +2424,14 @@ void ASTDeclReader::VisitImplicitConceptSpecializationDecl(
 void ASTDeclReader::VisitRequiresExprBodyDecl(RequiresExprBodyDecl *D) {
 }
 
+void ASTDeclReader::ReadSpecializations(ModuleFile &M, Decl *D,
+                                        llvm::BitstreamCursor &DeclsCursor) {
+  uint64_t Offset = ReadLocalOffset();
+  bool Failed = Reader.ReadSpecializations(M, DeclsCursor, Offset, D);
+  (void)Failed;
+  assert(!Failed);
+}
+
 RedeclarableResult
 ASTDeclReader::VisitRedeclarableTemplateDecl(RedeclarableTemplateDecl *D) {
   RedeclarableResult Redecl = VisitRedeclarable(D);
@@ -2462,6 +2473,7 @@ void ASTDeclReader::VisitClassTemplateDecl(ClassTemplateDecl *D) {
     SmallVector<LazySpecializationInfo, 32> SpecIDs;
     readDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
   }
 
   if (D->getTemplatedDecl()->TemplateOrInstantiation) {
@@ -2490,6 +2502,7 @@ void ASTDeclReader::VisitVarTemplateDecl(VarTemplateDecl *D) {
     SmallVector<LazySpecializationInfo, 32> SpecIDs;
     readDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
   }
 }
 
@@ -2591,6 +2604,7 @@ void ASTDeclReader::VisitFunctionTemplateDecl(FunctionTemplateDecl *D) {
     SmallVector<LazySpecializationInfo, 32> SpecIDs;
     readDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
   }
 }
 
@@ -3880,6 +3894,7 @@ Decl *ASTReader::ReadDeclRecord(GlobalDeclID ID) {
   switch ((DeclCode)MaybeDeclCode.get()) {
   case DECL_CONTEXT_LEXICAL:
   case DECL_CONTEXT_VISIBLE:
+  case DECL_SPECIALIZATIONS:
     llvm_unreachable("Record cannot be de-serialized with readDeclRecord");
   case DECL_TYPEDEF:
     D = TypedefDecl::CreateDeserialized(Context, ID);
@@ -4343,12 +4358,14 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
   assert((PendingLazySpecializationIDs.empty() || isa<ClassTemplateDecl>(D) ||
           isa<FunctionTemplateDecl, VarTemplateDecl>(D)) &&
          "Must not have pending specializations");
+  /*
   if (auto *CTD = dyn_cast<ClassTemplateDecl>(D))
     ASTDeclReader::AddLazySpecializations(CTD, PendingLazySpecializationIDs);
   else if (auto *FTD = dyn_cast<FunctionTemplateDecl>(D))
     ASTDeclReader::AddLazySpecializations(FTD, PendingLazySpecializationIDs);
   else if (auto *VTD = dyn_cast<VarTemplateDecl>(D))
     ASTDeclReader::AddLazySpecializations(VTD, PendingLazySpecializationIDs);
+  */
   PendingLazySpecializationIDs.clear();
 
   // Load the pending visible updates for this decl context, if it has any.
@@ -4374,6 +4391,16 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
       FunctionToLambdasMap.erase(IT);
     }
   }
+
+  // Load the pending specializations update for this decl, if it has any.
+  if (auto I = PendingSpecializationsUpdates.find(ID);
+      I != PendingSpecializationsUpdates.end()) {
+    auto SpecializationUpdates = std::move(I->second);
+    PendingSpecializationsUpdates.erase(I);
+
+    for (const auto &Update : SpecializationUpdates)
+      AddSpecializations(D, Update.Data, *Update.Mod);
+  }
 }
 
 void ASTReader::loadPendingDeclChain(Decl *FirstLocal, uint64_t LocalOffset) {
diff --git a/clang/lib/Serialization/ASTReaderInternals.h b/clang/lib/Serialization/ASTReaderInternals.h
index 536b19f91691eb..43b8aa41f91a20 100644
--- a/clang/lib/Serialization/ASTReaderInternals.h
+++ b/clang/lib/Serialization/ASTReaderInternals.h
@@ -119,6 +119,102 @@ struct DeclContextLookupTable {
   MultiOnDiskHashTable<ASTDeclContextNameLookupTrait> Table;
 };
 
+struct LazySpecializationInfo {
+  // The Decl ID for the specialization.
+  GlobalDeclID ID;
+  // Whether or not this specialization is partial.
+  bool IsPartial;
+
+  bool operator==(const LazySpecializationInfo &Other) {
+    assert(ID != Other.ID || IsPartial == Other.IsPartial);
+    return ID == Other.ID;
+  }
+
+  // The size of one serialized record in the OnDiskHashTable. This is not
+  // sizeof(LazySpecializationInfo), which may include alignment padding.
+  static constexpr unsigned Length = sizeof(DeclID) + sizeof(IsPartial);
+};
+
+/// Class that performs lookup to specialized decls.
+class LazySpecializationInfoLookupTrait {
+  ASTReader &Reader;
+  ModuleFile &F;
+
+public:
+  // Maximum number of lookup tables we allow before condensing the tables.
+  static const int MaxTables = 4;
+
+  /// The lookup result is a list of LazySpecializationInfo entries.
+  using data_type = SmallVector<LazySpecializationInfo, 4>;
+
+  struct data_type_builder {
+    data_type &Data;
+    llvm::DenseSet<LazySpecializationInfo> Found;
+
+    data_type_builder(data_type &D) : Data(D) {}
+
+    void insert(LazySpecializationInfo Info) {
+      // Append Info to Data unless it is already present. While Found is unused, scan Data.
+      if (Found.empty() && !Data.empty()) {
+        if (Data.size() <= 4) {
+          for (auto I : Data)
+            if (I == Info)
+              return;
+          Data.push_back(Info);
+          return;
+        }
+
+        // Too many entries for a linear scan: switch to tracking them in the set.
+        Found.insert(Data.begin(), Data.end());
+      }
+
+      if (Found.insert(Info).second)
+        Data.push_back(Info);
+    }
+  };
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+  using file_type = ModuleFile *;
+
+  using external_key_type = unsigned;
+  using internal_key_type = unsigned;
+
+  explicit LazySpecializationInfoLookupTrait(ASTReader &Reader, ModuleFile &F)
+      : Reader(Reader), F(F) {}
+
+  static bool EqualKey(const internal_key_type &a, const internal_key_type &b) {
+    return a == b;
+  }
+
+  static hash_value_type ComputeHash(const internal_key_type &Key) {
+    return Key;
+  }
+
+  static internal_key_type GetInternalKey(const external_key_type &Name) {
+    return Name;
+  }
+
+  static std::pair<unsigned, unsigned>
+  ReadKeyDataLength(const unsigned char *&d);
+
+  internal_key_type ReadKey(const unsigned char *d, unsigned);
+
+  void ReadDataInto(internal_key_type, const unsigned char *d, unsigned DataLen,
+                    data_type_builder &Val);
+
+  static void MergeDataInto(const data_type &From, data_type_builder &To) {
+    To.Data.reserve(To.Data.size() + From.size());
+    for (LazySpecializationInfo Info : From)
+      To.insert(Info);
+  }
+
+  file_type ReadFileRef(const unsigned char *&d);
+};
+
+struct LazySpecializationInfoLookupTable {
+  MultiOnDiskHashTable<LazySpecializationInfoLookupTrait> Table;
+};
+
 /// Base class for the trait describing the on-disk hash table for the
 /// identifiers in an AST file.
 ///
@@ -291,4 +387,32 @@ using HeaderFileInfoLookupTable =
 
 } // namespace clang
 
+namespace llvm {
+// The ID uniquely identifies a LazySpecializationInfo, so IsPartial does not
+// need to take part in hashing or equality.
+template <>
+struct DenseMapInfo<clang::serialization::reader::LazySpecializationInfo> {
+  using LazySpecializationInfo =
+      clang::serialization::reader::LazySpecializationInfo;
+  using Wrapped = DenseMapInfo<clang::serialization::DeclID>;
+
+  static inline LazySpecializationInfo getEmptyKey() {
+    return {(clang::GlobalDeclID)Wrapped::getEmptyKey(), false};
+  }
+
+  static inline LazySpecializationInfo getTombstoneKey() {
+    return {(clang::GlobalDeclID)Wrapped::getTombstoneKey(), false};
+  }
+
+  static unsigned getHashValue(const LazySpecializationInfo &Key) {
+    return Wrapped::getHashValue(Key.ID.getRawValue());
+  }
+
+  static bool isEqual(const LazySpecializationInfo &LHS,
+                      const LazySpecializationInfo &RHS) {
+    return LHS.ID == RHS.ID;
+  }
+};
+} // end namespace llvm
+
 #endif // LLVM_CLANG_LIB_SERIALIZATION_ASTREADERINTERNALS_H
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 6185749c9a9de8..8c48f8bbfe4e44 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -4126,6 +4126,155 @@ class ASTDeclContextNameLookupTrait {
 
 } // namespace
 
+namespace {
+class LazySpecializationInfoLookupTrait {
+  ASTWriter &Writer;
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 64> Specs;
+
+public:
+  using key_type = unsigned;
+  using key_type_ref = key_type;
+
+  /// A start and end index into Specs, representing a sequence of decls.
+  using data_type = std::pair<unsigned, unsigned>;
+  using data_type_ref = const data_type &;
+
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+
+  explicit LazySpecializationInfoLookupTrait(ASTWriter &Writer)
+      : Writer(Writer) {}
+  /// Append one entry per decl in \p C to Specs; return their index range.
+  template <typename Col> data_type getData(Col &&C) {
+    unsigned Start = Specs.size();
+    for (auto *D : C) {
+      bool IsPartial = isa<ClassTemplatePartialSpecializationDecl,
+                           VarTemplatePartialSpecializationDecl>(D);
+      NamedDecl *ND = getDeclForLocalLookup(
+                           Writer.getLangOpts(), const_cast<NamedDecl *>(D));
+      Specs.push_back({GlobalDeclID(Writer.GetDeclRef(ND).getRawValue()),
+                       IsPartial});
+    }
+    return std::make_pair(Start, Specs.size());
+  }
+  /// Copy the entries in \p FromReader into Specs; return their index range.
+  data_type ImportData(
+      const reader::LazySpecializationInfoLookupTrait::data_type &FromReader) {
+    unsigned Start = Specs.size();
+    for (auto ID : FromReader)
+      Specs.push_back(ID);
+    return std::make_pair(Start, Specs.size());
+  }
+  /// Keys are plain integers, so equality is a direct comparison.
+  static bool EqualKey(key_type_ref a, key_type_ref b) { return a == b; }
+  /// The key is returned unchanged as its own hash value.
+  hash_value_type ComputeHash(key_type Name) { return Name; }
+  /// Write the module file ID of \p F as a little-endian 32-bit value.
+  void EmitFileRef(raw_ostream &Out, ModuleFile *F) const {
+    assert(Writer.hasChain() &&
+           "have reference to loaded module file but no chain?");
+
+    using namespace llvm::support;
+    endian::write<uint32_t>(Out, Writer.getChain()->getModuleFileID(F),
+                            llvm::endianness::little);
+  }
+  /// Emit the key length (4) and the byte length of the entries in \p Lookup.
+  std::pair<unsigned, unsigned> EmitKeyDataLength(raw_ostream &Out,
+                                                  key_type HashValue,
+                                                  data_type_ref Lookup) {
+    // 4 bytes for each slot.
+    unsigned KeyLen = 4;
+    unsigned DataLen = serialization::reader::LazySpecializationInfo::Length *
+                       (Lookup.second - Lookup.first);
+
+    return emitULEBKeyDataLength(KeyLen, DataLen, Out);
+  }
+  /// Write the key as a little-endian 32-bit value.
+  void EmitKey(raw_ostream &Out, key_type HashValue, unsigned) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    LE.write<uint32_t>(HashValue);
+  }
+  /// Write each entry of \p Lookup as its raw DeclID followed by IsPartial.
+  void EmitData(raw_ostream &Out, key_type_ref, data_type Lookup,
+                unsigned DataLen) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    uint64_t Start = Out.tell();
+    (void)Start;
+    for (unsigned I = Lookup.first, N = Lookup.second; I != N; ++I) {
+      LE.write<DeclID>(Specs[I].ID.getRawValue());
+      LE.write<bool>(Specs[I].IsPartial);
+    }
+    assert(Out.tell() - Start == DataLen && "Data length is wrong");
+  }
+};
+
+/// Compute the ODR hash of the template arguments of specialization \p Spec.
+unsigned CalculateODRHashForSpecs(const Decl *Spec) {
+  if (const auto *ClassSpec = dyn_cast<ClassTemplateSpecializationDecl>(Spec))
+    return TemplateArgumentList::ComputeODRHash(
+        ClassSpec->getTemplateArgs().asArray());
+  if (const auto *VarSpec = dyn_cast<VarTemplateSpecializationDecl>(Spec))
+    return TemplateArgumentList::ComputeODRHash(
+        VarSpec->getTemplateArgs().asArray());
+  if (const auto *FuncSpec = dyn_cast<FunctionDecl>(Spec))
+    return TemplateArgumentList::ComputeODRHash(
+        FuncSpec->getTemplateSpecializationArgs()->asArray());
+  llvm_unreachable("New Specialization Kind?");
+}
+} // namespace
+
+/// Build the on-disk hash table for the specializations of the template \p D
+/// and emit it into \p LookupTable. Specializations are grouped under the ODR
+/// hash of their template arguments, which serves as the table key.
+void ASTWriter::GenerateSpecializationInfoLookupTable(
+    const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations,
+    llvm::SmallVectorImpl<char> &LookupTable) {
+  assert(D->isFirstDecl());
+
+  // Create the on-disk hash table representation.
+  MultiOnDiskHashTableGenerator<reader::LazySpecializationInfoLookupTrait,
+                                LazySpecializationInfoLookupTrait>
+      Generator;
+  LazySpecializationInfoLookupTrait Trait(*this);
+
+  // operator[] creates the bucket on first use, so one lookup per
+  // specialization suffices.
+  llvm::DenseMap<unsigned, llvm::SmallVector<const NamedDecl *, 4>>
+      SpecializationMaps;
+
+  for (auto *Specialization : Specializations) {
+    unsigned HashedValue = CalculateODRHashForSpecs(Specialization);
+    SpecializationMaps[HashedValue].push_back(cast<NamedDecl>(Specialization));
+  }
+
+  // Iterate by reference: copying a bucket would copy its whole SmallVector.
+  for (auto &Bucket : SpecializationMaps)
+    Generator.insert(Bucket.first, Trait.getData(Bucket.second), Trait);
+
+  // If the chained reader already has tables loaded for D, pass them to
+  // emit() along with the entries generated above.
+  auto *Lookups =
+      Chain ? Chain->getLoadedSpecializationsLookupTables(D) : nullptr;
+  Generator.emit(LookupTable, Trait, Lookups ? &Lookups->Table : nullptr);
+}
+
+/// Emit a DECL_SPECIALIZATIONS record whose blob is the lookup table for the
+/// specializations of \p D; returns the stream bit position where it starts.
+uint64_t ASTWriter::WriteSpecializationInfoLookupTable(
+    const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations) {
+  llvm::SmallString<4096> TableBlob;
+  GenerateSpecializationInfoLookupTable(D, Specializations, TableBlob);
+  // Capture the position before emitting so it points at the record itself.
+  const uint64_t RecordStart = Stream.GetCurrentBitNo();
+  RecordData::value_type Record[] = {DECL_SPECIALIZATIONS};
+  Stream.EmitRecordWithBlob(DeclSpecializationsAbbrev, Record, TableBlob);
+  return RecordStart;
+}
+
 bool ASTWriter::isLookupResultExternal(StoredDeclsList &Result,
                                        DeclContext *DC) {
   return Result.hasExternalDecls() &&
@@ -5666,7 +5815,7 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
   // Keep writing types, declarations, and declaration update records
   // until we've emitted all of them.
   RecordData DeclUpdatesOffsetsRecord;
-  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/5);
+  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/6);
   DeclTypesBlockStartOffset = Stream.GetCurrentBitNo();
   WriteTypeAbbrevs();
   WriteDeclAbbrevs();
@@ -5740,6 +5889,9 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
                       FunctionToLambdaMapAbbrev);
   }
 
+  if (!SpecializationsUpdates.empty())
+    WriteSpecializationsUpdates();
+
   const TranslationUnitDecl *TU = Context.getTranslationUnitDecl();
   // Create a lexical update block containing all of the declarations in the
   // translation unit that do not come from other AST files.
@@ -5783,6 +5935,26 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
     WriteDeclContextVisibleUpdate(DC);
 }
 
+/// Emit one UPDATE_SPECIALIZATION record per entry of SpecializationsUpdates.
+/// Each record holds the ID of the template and, as blob, its lookup table.
+void ASTWriter::WriteSpecializationsUpdates() {
+  // Abbreviation layout: record code, VBR6 declaration ID, table blob.
+  auto Abbrev = std::make_shared<llvm::BitCodeAbbrev>();
+  Abbrev->Add(llvm::BitCodeAbbrevOp(UPDATE_SPECIALIZATION));
+  Abbrev->Add(llvm::BitCodeAbbrevOp(llvm::BitCodeAbbrevOp::VBR, 6));
+  Abbrev->Add(llvm::BitCodeAbbrevOp(llvm::BitCodeAbbrevOp::Blob));
+  const auto AbbrevID = Stream.EmitAbbrev(std::move(Abbrev));
+
+  for (auto &Update : SpecializationsUpdates) {
+    const NamedDecl *Template = Update.first;
+    llvm::SmallString<4096> TableBlob;
+    GenerateSpecializationInfoLookupTable(Template, Update.second, TableBlob);
+    RecordData::value_type Record[] = {UPDATE_SPECIALIZATION,
+                                       getDeclID(Template).getRawValue()};
+    Stream.EmitRecordWithBlob(AbbrevID, Record, TableBlob);
+  }
+}
+
 void ASTWriter::WriteDeclUpdatesBlocks(RecordDataImpl &OffsetsRecord) {
   if (DeclUpdates.empty())
     return;
diff --git a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp
index e5ada3fdbb1c48..d2766ceae9b8c9 100644
--- a/clang/lib/Serialization/ASTWriterDecl.cpp
+++ b/clang/lib/Serialization/ASTWriterDecl.cpp
@@ -208,15 +208,17 @@ namespace clang {
     /// file that provides a declaration of D. We store the DeclId and an
     /// ODRHash of the template arguments of D which should provide enough
     /// information to load D only if the template instantiator needs it.
-    void AddFirstSpecializationDeclFromEachModule(const Decl *D,
-                                                  bool IncludeLocal) {
-      assert(isa<ClassTemplateSpecializationDecl>(D) ||
-             isa<VarTemplateSpecializationDecl>(D) ||
-             isa<FunctionDecl>(D) && "Must not be called with other decls");
+    void AddFirstSpecializationDeclFromEachModule(
+        const Decl *D, llvm::SmallVectorImpl<const Decl *> &SpecsInMap) {
+      assert((isa<ClassTemplateSpecializationDecl>(D) ||
+              isa<VarTemplateSpecializationDecl>(D) || isa<FunctionDecl>(D)) &&
+             "Must not be called with other decls");
       llvm::MapVector<ModuleFile *, const Decl *> Firsts;
-      CollectFirstDeclFromEachModule(D, IncludeLocal, Firsts);
+      CollectFirstDeclFromEachModule(D, /*IncludeLocal*/ true, Firsts);
 
       for (const auto &F : Firsts) {
+        SpecsInMap.push_back(F.second);
+
         Record.AddDeclRef(F.second);
         ArrayRef<TemplateArgument> Args;
         if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
@@ -283,10 +285,15 @@ namespace clang {
       for (auto &Entry : getPartialSpecializations(Common))
         Specs.push_back(getSpecializationDecl(Entry));
 
+      llvm::SmallVector<const Decl *, 16> SpecsInOnDiskMap = Specs;
+
       for (auto *D : Specs) {
         assert(D->isCanonicalDecl() && "non-canonical decl in set");
-        AddFirstSpecializationDeclFromEachModule(D, /*IncludeLocal*/ true);
+        AddFirstSpecializationDeclFromEachModule(D, SpecsInOnDiskMap);
       }
+
+      // We don't need to insert LazySpecializations to SpecsInOnDiskMap,
+      // since we'll handle that in GenerateSpecializationInfoLookupTable.
       for (auto &SpecInfo : LazySpecializations) {
         Record.push_back(SpecInfo.DeclID.getRawValue());
         Record.push_back(SpecInfo.ODRHash);
@@ -299,6 +306,9 @@ namespace clang {
       assert((Record.size() - I - 1) % 3 == 0 &&
              "Must be divisible by LazySpecializationInfo count!");
       Record[I] = (Record.size() - I - 1) / 3;
+
+      Record.AddOffset(
+          Writer.WriteSpecializationInfoLookupTable(D, SpecsInOnDiskMap));
     }
 
     /// Ensure that this template specialization is associated with the specified
@@ -321,6 +331,9 @@ namespace clang {
 
       Writer.DeclUpdates[Template].push_back(ASTWriter::DeclUpdate(
           UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION, Specialization));
+
+      Writer.SpecializationsUpdates[cast<NamedDecl>(Template)].push_back(
+          cast<NamedDecl>(Specialization));
     }
   };
 }
@@ -2816,6 +2829,11 @@ void ASTWriter::WriteDeclAbbrevs() {
   Abv->Add(BitCodeAbbrevOp(serialization::DECL_CONTEXT_VISIBLE));
   Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
   DeclContextVisibleLookupAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+  Abv = std::make_shared<BitCodeAbbrev>();
+  Abv->Add(BitCodeAbbrevOp(serialization::DECL_SPECIALIZATIONS));
+  Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
+  DeclSpecializationsAbbrev = Stream.EmitAbbrev(std::move(Abv));
 }
 
 /// isRequiredDecl - Check if this is a "required" Decl, which must be seen by
diff --git a/clang/unittests/Serialization/CMakeLists.txt b/clang/unittests/Serialization/CMakeLists.txt
index e7eebd0cb98239..e7005b5d511eba 100644
--- a/clang/unittests/Serialization/CMakeLists.txt
+++ b/clang/unittests/Serialization/CMakeLists.txt
@@ -11,6 +11,7 @@ add_clang_unittest(SerializationTests
   ModuleCacheTest.cpp
   NoCommentsTest.cpp
   PreambleInNamedModulesTest.cpp
+  LoadSpecLazilyTest.cpp
   SourceLocationEncodingTest.cpp
   VarDeclConstantInitTest.cpp
   )
diff --git a/clang/unittests/Serialization/LoadSpecLazilyTest.cpp b/clang/unittests/Serialization/LoadSpecLazilyTest.cpp
new file mode 100644
index 00000000000000..76e3ccae3d3c3d
--- /dev/null
+++ b/clang/unittests/Serialization/LoadSpecLazilyTest.cpp
@@ -0,0 +1,260 @@
+//== unittests/Serialization/LoadSpecLazily.cpp ----------------------========//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Frontend/FrontendActions.h"
+#include "clang/Parse/ParseAST.h"
+#include "clang/Serialization/ASTDeserializationListener.h"
+#include "clang/Tooling/Tooling.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace clang;
+using namespace clang::tooling;
+
+namespace {
+/// Fixture that owns a unique scratch directory and precompiles modules in it.
+class LoadSpecLazilyTest : public ::testing::Test {
+  void SetUp() override {
+    ASSERT_FALSE(
+        sys::fs::createUniqueDirectory("load-spec-lazily-test", TestDir));
+  }
+  // Best-effort cleanup: the result of remove_directories is not checked.
+  void TearDown() override { sys::fs::remove_directories(TestDir); }
+
+public:
+  SmallString<256> TestDir;
+  /// Write \p Contents to the relative \p Path below TestDir.
+  void addFile(StringRef Path, StringRef Contents) {
+    ASSERT_FALSE(sys::path::is_absolute(Path));
+
+    SmallString<256> AbsPath(TestDir);
+    sys::path::append(AbsPath, Path);
+
+    ASSERT_FALSE(
+        sys::fs::create_directories(llvm::sys::path::parent_path(AbsPath)));
+
+    std::error_code EC;
+    llvm::raw_fd_ostream OS(AbsPath, EC);
+    ASSERT_FALSE(EC);
+    OS << Contents;
+  }
+  /// Precompile \p Contents as <ModuleName>.cppm; returns the .pcm path.
+  std::string GenerateModuleInterface(StringRef ModuleName,
+                                      StringRef Contents) {
+    std::string FileName = llvm::Twine(ModuleName + ".cppm").str();
+    addFile(FileName, Contents);
+
+    IntrusiveRefCntPtr<DiagnosticsEngine> Diags =
+        CompilerInstance::createDiagnostics(new DiagnosticOptions());
+    CreateInvocationOptions CIOpts;
+    CIOpts.Diags = Diags;
+    CIOpts.VFS = llvm::vfs::createPhysicalFileSystem();
+
+    std::string CacheBMIPath =
+        llvm::Twine(TestDir + "/" + ModuleName + ".pcm").str();
+    std::string PrebuiltModulePath =
+        "-fprebuilt-module-path=" + TestDir.str().str();
+    const char *Args[] = {"clang++",
+                          "-std=c++20",
+                          "--precompile",
+                          PrebuiltModulePath.c_str(),
+                          "-working-directory",
+                          TestDir.c_str(),
+                          "-I",
+                          TestDir.c_str(),
+                          FileName.c_str(),
+                          "-o",
+                          CacheBMIPath.c_str()};
+    std::shared_ptr<CompilerInvocation> Invocation =
+        createInvocation(Args, CIOpts);
+    EXPECT_TRUE(Invocation);
+
+    CompilerInstance Instance;
+    Instance.setDiagnostics(Diags.get());
+    Instance.setInvocation(Invocation);
+    Instance.getFrontendOpts().OutputFile = CacheBMIPath;
+    GenerateModuleInterfaceAction Action;
+    EXPECT_TRUE(Instance.ExecuteAction(Action));
+    EXPECT_FALSE(Diags->hasErrorOccurred());
+
+    return CacheBMIPath;
+  }
+};
+/// Forbidden: the named decl must never be read. Required: it must be read.
+enum class CheckingMode { Forbidden, Required };
+
+/// Listener recording whether a decl whose name contains SpecifiedName was
+/// deserialized; Forbidden fails once one is read, Required fails if none was.
+class DeclsReaderListener : public ASTDeserializationListener {
+  StringRef SpecifiedName;
+  CheckingMode Mode;
+  bool SawSpecifiedName = false;
+
+public:
+  DeclsReaderListener(StringRef SpecifiedName, CheckingMode Mode)
+      : SpecifiedName(SpecifiedName), Mode(Mode) {}
+  ~DeclsReaderListener() override {
+    if (Mode == CheckingMode::Required) {
+      EXPECT_TRUE(SawSpecifiedName);
+    }
+  }
+
+  void DeclRead(GlobalDeclID ID, const Decl *D) override {
+    // getName() asserts unless the name is a plain identifier; skip the rest.
+    const auto *ND = dyn_cast<NamedDecl>(D);
+    if (!ND || !ND->getDeclName().isIdentifier())
+      return;
+    SawSpecifiedName |= ND->getName().contains(SpecifiedName);
+    if (Mode == CheckingMode::Forbidden) {
+      EXPECT_FALSE(SawSpecifiedName);
+    }
+  }
+};
+/// ASTConsumer that owns a DeclsReaderListener and exposes it as its listener.
+class LoadSpecLazilyConsumer : public ASTConsumer {
+  DeclsReaderListener Listener;
+
+public:
+  LoadSpecLazilyConsumer(StringRef SpecifiedName, CheckingMode Mode)
+      : Listener(SpecifiedName, Mode) {}
+  /// Hand out the owned listener.
+  ASTDeserializationListener *GetASTDeserializationListener() override {
+    return &Listener;
+  }
+};
+/// Frontend action that creates a LoadSpecLazilyConsumer for the given name.
+class CheckLoadSpecLazilyAction : public ASTFrontendAction {
+  StringRef SpecifiedName;
+  CheckingMode Mode;
+
+public:
+  std::unique_ptr<ASTConsumer>
+  CreateASTConsumer(CompilerInstance &CI, StringRef /*Unused*/) override {
+    return std::make_unique<LoadSpecLazilyConsumer>(SpecifiedName, Mode);
+  }
+  /// Store the name and checking mode that are passed on to each consumer.
+  CheckLoadSpecLazilyAction(StringRef SpecifiedName, CheckingMode Mode)
+      : SpecifiedName(SpecifiedName), Mode(Mode) {}
+};
+/// Instantiating A<int> must not deserialize any decl named ShouldNotBeLoaded.
+TEST_F(LoadSpecLazilyTest, BasicTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+export class ShouldNotBeLoaded {};
+export class Temp {
+   A<ShouldNotBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import M;
+A<int> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(
+      runToolOnCodeWithArgs(std::make_unique<CheckLoadSpecLazilyAction>(
+                                "ShouldNotBeLoaded", CheckingMode::Forbidden),
+                            test_file_contents,
+                            {
+                                "-std=c++20",
+                                DepArg.c_str(),
+                                "-I",
+                                TestDir.c_str(),
+                            },
+                            "test.cpp"));
+}
+/// Same check when the specialization lives in N, which re-exports M.
+TEST_F(LoadSpecLazilyTest, ChainedTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+  )cpp");
+
+  GenerateModuleInterface("N", R"cpp(
+export module N;
+export import M;
+export class ShouldNotBeLoaded {};
+export class Temp {
+   A<ShouldNotBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import N;
+A<int> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(
+      runToolOnCodeWithArgs(std::make_unique<CheckLoadSpecLazilyAction>(
+                                "ShouldNotBeLoaded", CheckingMode::Forbidden),
+                            test_file_contents,
+                            {
+                                "-std=c++20",
+                                DepArg.c_str(),
+                                "-I",
+                                TestDir.c_str(),
+                            },
+                            "test.cpp"));
+}
+
+/// Test that we don't crash when the lazy specialization lookup table gets
+/// invalidated while specializations are being loaded.
+TEST_F(LoadSpecLazilyTest, ChainedTest2) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+
+export class B {};
+
+export class C {
+  A<B> D;
+};
+  )cpp");
+
+  GenerateModuleInterface("N", R"cpp(
+export module N;
+export import M;
+export class MayBeLoaded {};
+
+export class Temp {
+   A<MayBeLoaded> AS;
+};
+
+export class ExportedClass {};
+
+export template<> class A<ExportedClass> {
+   A<MayBeLoaded> AS;
+   A<B>           AB;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import N;
+Temp T;
+A<ExportedClass> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(runToolOnCodeWithArgs(std::make_unique<CheckLoadSpecLazilyAction>(
+                                        "MayBeLoaded", CheckingMode::Required),
+                                    test_file_contents,
+                                    {
+                                        "-std=c++20",
+                                        DepArg.c_str(),
+                                        "-I",
+                                        TestDir.c_str(),
+                                    },
+                                    "test.cpp"));
+}
+
+} // namespace



More information about the llvm-branch-commits mailing list