[llvm-branch-commits] [clang] [Serialization] Introduce OnDiskHashTable for specializations (PR #83233)

Chuanqi Xu via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Nov 8 01:20:07 PST 2024


https://github.com/ChuanqiXu9 updated https://github.com/llvm/llvm-project/pull/83233

>From 11726437efb760c9f2aba9b2258337b2b8eb4bb6 Mon Sep 17 00:00:00 2001
From: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
Date: Fri, 8 Nov 2024 17:19:33 +0800
Subject: [PATCH] [Serialization] Introduce OnDiskHashTable for specializations

---
 clang/include/clang/AST/ExternalASTSource.h   |  11 +
 .../clang/Sema/MultiplexExternalSemaSource.h  |   6 +
 .../include/clang/Serialization/ASTBitCodes.h |   6 +
 clang/include/clang/Serialization/ASTReader.h |  34 ++-
 clang/include/clang/Serialization/ASTWriter.h |  14 +
 clang/lib/AST/DeclTemplate.cpp                |  17 ++
 clang/lib/AST/ExternalASTSource.cpp           |   5 +
 .../lib/Sema/MultiplexExternalSemaSource.cpp  |  12 +
 clang/lib/Serialization/ASTReader.cpp         | 145 +++++++++-
 clang/lib/Serialization/ASTReaderDecl.cpp     |  27 ++
 clang/lib/Serialization/ASTReaderInternals.h  | 124 +++++++++
 clang/lib/Serialization/ASTWriter.cpp         | 174 +++++++++++-
 clang/lib/Serialization/ASTWriterDecl.cpp     |  32 ++-
 clang/unittests/Serialization/CMakeLists.txt  |   1 +
 .../Serialization/LoadSpecLazilyTest.cpp      | 260 ++++++++++++++++++
 15 files changed, 855 insertions(+), 13 deletions(-)
 create mode 100644 clang/unittests/Serialization/LoadSpecLazilyTest.cpp

diff --git a/clang/include/clang/AST/ExternalASTSource.h b/clang/include/clang/AST/ExternalASTSource.h
index 582ed7c65f58ca..5f4f9a9a8d681e 100644
--- a/clang/include/clang/AST/ExternalASTSource.h
+++ b/clang/include/clang/AST/ExternalASTSource.h
@@ -152,6 +152,17 @@ class ExternalASTSource : public RefCountedBase<ExternalASTSource> {
   virtual bool
   FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name);
 
+  /// Load all the external specializations of the Decl \p D. If
+  /// \p OnlyPartial is true, load only the external **partial**
+  /// specializations of \p D.
+  virtual void LoadExternalSpecializations(const Decl *D, bool OnlyPartial);
+
+  /// Load the external specializations of the Decl \p D that may match the
+  /// template arguments \p TemplateArgs.
+  virtual void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs);
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   ///
diff --git a/clang/include/clang/Sema/MultiplexExternalSemaSource.h b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
index 3d1906d8699265..78bbbaf2d7b5c6 100644
--- a/clang/include/clang/Sema/MultiplexExternalSemaSource.h
+++ b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
@@ -97,6 +97,12 @@ class MultiplexExternalSemaSource : public ExternalSemaSource {
   bool FindExternalVisibleDeclsByName(const DeclContext *DC,
                                       DeclarationName Name) override;
 
+  void LoadExternalSpecializations(const Decl *D, bool OnlyPartial) override;
+
+  void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   void completeVisibleDeclsMap(const DeclContext *DC) override;
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index 3b14a0b8203315..cb3ed6c1ecbb7c 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -734,6 +734,9 @@ enum ASTRecordTypes {
   /// Record code for Sema's vector of functions/blocks with effects to
   /// be verified.
   DECLS_WITH_EFFECTS_TO_VERIFY = 72,
+
+  /// Record code for an update to a template's specializations lookup table.
+  UPDATE_SPECIALIZATION = 73,
 };
 
 /// Record types used within a source manager block.
@@ -1500,6 +1503,9 @@ enum DeclCode {
   /// A HLSLBufferDecl record.
   DECL_HLSL_BUFFER,
 
+  /// A record holding the specializations lookup table of a template decl.
+  DECL_SPECIALIZATIONS,
+
   /// An ImplicitConceptSpecializationDecl record.
   DECL_IMPLICIT_CONCEPT_SPECIALIZATION,
 
diff --git a/clang/include/clang/Serialization/ASTReader.h b/clang/include/clang/Serialization/ASTReader.h
index 9c274adc59a207..6306d4f08e81fa 100644
--- a/clang/include/clang/Serialization/ASTReader.h
+++ b/clang/include/clang/Serialization/ASTReader.h
@@ -354,6 +354,9 @@ class ASTIdentifierLookupTrait;
 /// The on-disk hash table(s) used for DeclContext name lookup.
 struct DeclContextLookupTable;
 
+/// The on-disk hash table(s) used for specialization decls.
+struct LazySpecializationInfoLookupTable;
+
 } // namespace reader
 
 } // namespace serialization
@@ -632,20 +635,29 @@ class ASTReader
   llvm::DenseMap<const DeclContext *,
                  serialization::reader::DeclContextLookupTable> Lookups;
 
+  /// Map from canonical decls to the lookup tables of their specializations.
+  llvm::DenseMap<const Decl *,
+                 serialization::reader::LazySpecializationInfoLookupTable>
+      SpecializationsLookups;
+
   // Updates for visible decls can occur for other contexts than just the
   // TU, and when we read those update records, the actual context may not
   // be available yet, so have this pending map using the ID as a key. It
-  // will be realized when the context is actually loaded.
-  struct PendingVisibleUpdate {
+  // will be realized when the data is actually loaded.
+  struct UpdateData {
     ModuleFile *Mod;
     const unsigned char *Data;
   };
-  using DeclContextVisibleUpdates = SmallVector<PendingVisibleUpdate, 1>;
+  using DeclContextVisibleUpdates = SmallVector<UpdateData, 1>;
 
   /// Updates to the visible declarations of declaration contexts that
   /// haven't been loaded yet.
   llvm::DenseMap<GlobalDeclID, DeclContextVisibleUpdates> PendingVisibleUpdates;
 
+  using SpecializationsUpdate = SmallVector<UpdateData, 1>;
+  llvm::DenseMap<GlobalDeclID, SpecializationsUpdate>
+      PendingSpecializationsUpdates;
+
   /// The set of C++ or Objective-C classes that have forward
   /// declarations that have not yet been linked to their definitions.
   llvm::SmallPtrSet<Decl *, 4> PendingDefinitions;
@@ -678,6 +690,11 @@ class ASTReader
                                      llvm::BitstreamCursor &Cursor,
                                      uint64_t Offset, GlobalDeclID ID);
 
+  bool ReadSpecializations(ModuleFile &M, llvm::BitstreamCursor &Cursor,
+                           uint64_t Offset, Decl *D);
+  void AddSpecializations(const Decl *D, const unsigned char *Data,
+                          ModuleFile &M);
+
   /// A vector containing identifiers that have already been
   /// loaded.
   ///
@@ -1419,6 +1436,11 @@ class ASTReader
   const serialization::reader::DeclContextLookupTable *
   getLoadedLookupTables(DeclContext *Primary) const;
 
+  /// Get the loaded specializations lookup tables for the canonical decl \p D,
+  /// if any.
+  serialization::reader::LazySpecializationInfoLookupTable *
+  getLoadedSpecializationsLookupTables(const Decl *D);
+
 private:
   struct ImportedModule {
     ModuleFile *Mod;
@@ -2076,6 +2098,12 @@ class ASTReader
                                       unsigned BlockID,
                                       uint64_t *StartOfBlockOffset = nullptr);
 
+  void LoadExternalSpecializations(const Decl *D, bool OnlyPartial) override;
+
+  void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Finds all the visible declarations with a given name.
   /// The current implementation of this method just loads the entire
   /// lookup table as unmaterialized references.
diff --git a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h
index dc9fcd3c33726e..9da22f96130a8b 100644
--- a/clang/include/clang/Serialization/ASTWriter.h
+++ b/clang/include/clang/Serialization/ASTWriter.h
@@ -423,6 +423,12 @@ class ASTWriter : public ASTDeserializationListener,
   /// Only meaningful for reduced BMI.
   DeclUpdateMap DeclUpdatesFromGMF;
 
+  /// Mapping from template decls to their new specializations in the
+  /// current TU.
+  using SpecializationUpdateMap =
+      llvm::MapVector<const NamedDecl *, SmallVector<const Decl *>>;
+  SpecializationUpdateMap SpecializationsUpdates;
+
   using FirstLatestDeclMap = llvm::DenseMap<Decl *, Decl *>;
 
   /// Map of first declarations from a chained PCH that point to the
@@ -572,6 +578,11 @@ class ASTWriter : public ASTDeserializationListener,
 
   bool isLookupResultExternal(StoredDeclsList &Result, DeclContext *DC);
 
+  void GenerateSpecializationInfoLookupTable(
+      const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations,
+      llvm::SmallVectorImpl<char> &LookupTable);
+  uint64_t WriteSpecializationInfoLookupTable(
+      const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations);
   void GenerateNameLookupTable(ASTContext &Context, const DeclContext *DC,
                                llvm::SmallVectorImpl<char> &LookupTable);
   uint64_t WriteDeclContextLexicalBlock(ASTContext &Context,
@@ -587,6 +598,7 @@ class ASTWriter : public ASTDeserializationListener,
   void WriteDeclAndTypes(ASTContext &Context);
   void PrepareWritingSpecialDecls(Sema &SemaRef);
   void WriteSpecialDeclRecords(Sema &SemaRef);
+  void WriteSpecializationsUpdates();
   void WriteDeclUpdatesBlocks(ASTContext &Context,
                               RecordDataImpl &OffsetsRecord);
   void WriteDeclContextVisibleUpdate(ASTContext &Context,
@@ -616,6 +628,8 @@ class ASTWriter : public ASTDeserializationListener,
   unsigned DeclEnumAbbrev = 0;
   unsigned DeclObjCIvarAbbrev = 0;
   unsigned DeclCXXMethodAbbrev = 0;
+  unsigned DeclSpecializationsAbbrev = 0;
+
   unsigned DeclDependentNonTemplateCXXMethodAbbrev = 0;
   unsigned DeclTemplateCXXMethodAbbrev = 0;
   unsigned DeclMemberSpecializedCXXMethodAbbrev = 0;
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index d0e16c30881be2..a73af9c7785320 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -355,6 +355,14 @@ RedeclarableTemplateDecl::CommonBase *RedeclarableTemplateDecl::getCommonPtr() c
 
 void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
     bool OnlyPartial /*=false*/) const {
+  auto *ExternalSource = getASTContext().getExternalSource();
+  if (!ExternalSource)
+    return;
+
+  ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(),
+                                              OnlyPartial);
+  return;
+  // NOTE(review): the code below this return is dead; remove it.
   // Grab the most recent declaration to ensure we've loaded any lazy
   // redeclarations of this template.
   CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
@@ -374,6 +382,8 @@ void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
 
 Decl *RedeclarableTemplateDecl::loadLazySpecializationImpl(
     LazySpecializationInfo &LazySpecInfo) const {
+  llvm_unreachable("We don't use LazySpecializationInfo any more");
+  // NOTE(review): the code below is dead after llvm_unreachable; remove it.
   GlobalDeclID ID = LazySpecInfo.DeclID;
   assert(ID.isValid() && "Loading already loaded specialization!");
   // Note that we loaded the specialization.
@@ -384,6 +394,13 @@ Decl *RedeclarableTemplateDecl::loadLazySpecializationImpl(
 
 void RedeclarableTemplateDecl::loadLazySpecializationsImpl(
     ArrayRef<TemplateArgument> Args, TemplateParameterList *TPL) const {
+  auto *ExternalSource = getASTContext().getExternalSource();
+  if (!ExternalSource)
+    return;
+
+  ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(), Args);
+  return;
+  // NOTE(review): the code below this return is dead; remove it.
   CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
   if (auto *Specs = CommonBasePtr->LazySpecializations) {
     unsigned Hash = TemplateArgumentList::ComputeODRHash(Args);
diff --git a/clang/lib/AST/ExternalASTSource.cpp b/clang/lib/AST/ExternalASTSource.cpp
index a5b6f80bde694c..122014bfeb2321 100644
--- a/clang/lib/AST/ExternalASTSource.cpp
+++ b/clang/lib/AST/ExternalASTSource.cpp
@@ -98,6 +98,11 @@ ExternalASTSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
   return false;
 }
 
+void ExternalASTSource::LoadExternalSpecializations(const Decl *D, bool) {}
+
+void ExternalASTSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument>) {}
+
 void ExternalASTSource::completeVisibleDeclsMap(const DeclContext *DC) {}
 
 void ExternalASTSource::FindExternalLexicalDecls(
diff --git a/clang/lib/Sema/MultiplexExternalSemaSource.cpp b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
index cd44483b5cbe04..f39463712ce81f 100644
--- a/clang/lib/Sema/MultiplexExternalSemaSource.cpp
+++ b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
@@ -115,6 +115,18 @@ FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name) {
   return AnyDeclsFound;
 }
 
+void MultiplexExternalSemaSource::LoadExternalSpecializations(
+    const Decl *D, bool OnlyPartial) {
+  for (size_t i = 0; i < Sources.size(); ++i)
+    Sources[i]->LoadExternalSpecializations(D, OnlyPartial);
+}
+
+void MultiplexExternalSemaSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  for (size_t i = 0; i < Sources.size(); ++i)
+    Sources[i]->LoadExternalSpecializations(D, TemplateArgs);
+}
+
 void MultiplexExternalSemaSource::completeVisibleDeclsMap(const DeclContext *DC){
   for(size_t i = 0; i < Sources.size(); ++i)
     Sources[i]->completeVisibleDeclsMap(DC);
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 8bba98ef33302b..75d81a25dd3ac3 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -1283,6 +1283,43 @@ void ASTDeclContextNameLookupTrait::ReadDataInto(internal_key_type,
   }
 }
 
+ModuleFile *
+LazySpecializationInfoLookupTrait::ReadFileRef(const unsigned char *&d) {
+  using namespace llvm::support;
+  // A file reference is a little-endian uint32 local module-file ID.
+  uint32_t ModuleFileID =
+      endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+  return Reader.getLocalModuleFile(F, ModuleFileID);
+}
+
+LazySpecializationInfoLookupTrait::internal_key_type
+LazySpecializationInfoLookupTrait::ReadKey(const unsigned char *d, unsigned) {
+  using namespace llvm::support;
+  return endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+}
+
+std::pair<unsigned, unsigned>
+LazySpecializationInfoLookupTrait::ReadKeyDataLength(const unsigned char *&d) {
+  return readULEBKeyDataLength(d);
+}
+
+void LazySpecializationInfoLookupTrait::ReadDataInto(internal_key_type,
+                                                     const unsigned char *d,
+                                                     unsigned DataLen,
+                                                     data_type_builder &Val) {
+  using namespace llvm::support;
+  // Each entry is a DeclID followed by a bool IsPartial flag.
+  for (unsigned NumDecls =
+           DataLen / serialization::reader::LazySpecializationInfo::Length;
+       NumDecls; --NumDecls) {
+    LocalDeclID LocalID =
+        LocalDeclID::get(Reader, F, endian::readNext<DeclID, llvm::endianness::little, unaligned>(d));
+    const bool IsPartial =
+        endian::readNext<bool, llvm::endianness::little, unaligned>(d);
+    Val.insert({Reader.getGlobalDeclID(F, LocalID), IsPartial});
+  }
+}
+
 bool ASTReader::ReadLexicalDeclContextStorage(ModuleFile &M,
                                               BitstreamCursor &Cursor,
                                               uint64_t Offset,
@@ -1367,7 +1404,49 @@ bool ASTReader::ReadVisibleDeclContextStorage(ModuleFile &M,
   // We can't safely determine the primary context yet, so delay attaching the
   // lookup table until we're done with recursive deserialization.
   auto *Data = (const unsigned char*)Blob.data();
-  PendingVisibleUpdates[ID].push_back(PendingVisibleUpdate{&M, Data});
+  PendingVisibleUpdates[ID].push_back(UpdateData{&M, Data});
+  return false;
+}
+
+void ASTReader::AddSpecializations(const Decl *D, const unsigned char *Data,
+                                   ModuleFile &M) {
+  D = D->getCanonicalDecl();
+  SpecializationsLookups[D].Table.add(
+      &M, Data, reader::LazySpecializationInfoLookupTrait(*this, M));
+}
+
+bool ASTReader::ReadSpecializations(ModuleFile &M, BitstreamCursor &Cursor,
+                                    uint64_t Offset, Decl *D) {
+  assert(Offset != 0);
+  // Read the DECL_SPECIALIZATIONS record located at Offset.
+  SavedStreamPosition SavedPosition(Cursor);
+  if (llvm::Error Err = Cursor.JumpToBit(Offset)) {
+    Error(std::move(Err));
+    return true;
+  }
+
+  RecordData Record;
+  StringRef Blob;
+  Expected<unsigned> MaybeCode = Cursor.ReadCode();
+  if (!MaybeCode) {
+    Error(MaybeCode.takeError());
+    return true;
+  }
+  unsigned Code = MaybeCode.get();
+
+  Expected<unsigned> MaybeRecCode = Cursor.readRecord(Code, Record, &Blob);
+  if (!MaybeRecCode) {
+    Error(MaybeRecCode.takeError());
+    return true;
+  }
+  unsigned RecCode = MaybeRecCode.get();
+  if (RecCode != DECL_SPECIALIZATIONS) {
+    Error("Expected decl specs block");
+    return true;
+  }
+  // The record blob is the on-disk lookup table; register it for D.
+  auto *Data = (const unsigned char *)Blob.data();
+  AddSpecializations(D, Data, M);
   return false;
 }
 
@@ -3454,7 +3533,20 @@ llvm::Error ASTReader::ReadASTBlock(ModuleFile &F,
       unsigned Idx = 0;
       GlobalDeclID ID = ReadDeclID(F, Record, Idx);
       auto *Data = (const unsigned char*)Blob.data();
-      PendingVisibleUpdates[ID].push_back(PendingVisibleUpdate{&F, Data});
+      PendingVisibleUpdates[ID].push_back(UpdateData{&F, Data});
+      // If we've already loaded the decl, perform the updates when we finish
+      // loading this block.
+      if (Decl *D = GetExistingDecl(ID))
+        PendingUpdateRecords.push_back(
+            PendingUpdateRecord(ID, D, /*JustLoaded=*/false));
+      break;
+    }
+
+    case UPDATE_SPECIALIZATION: {
+      unsigned Idx = 0;
+      GlobalDeclID ID = ReadDeclID(F, Record, Idx);
+      auto *Data = (const unsigned char *)Blob.data();
+      PendingSpecializationsUpdates[ID].push_back(UpdateData{&F, Data});
       // If we've already loaded the decl, perform the updates when we finish
       // loading this block.
       if (Decl *D = GetExistingDecl(ID))
@@ -8037,6 +8129,48 @@ Stmt *ASTReader::GetExternalDeclStmt(uint64_t Offset) {
   return ReadStmtFromStream(*Loc.F);
 }
 
+void ASTReader::LoadExternalSpecializations(const Decl *D, bool OnlyPartial) {
+  assert(D);
+
+  auto It = SpecializationsLookups.find(D);
+  if (It == SpecializationsLookups.end())
+    return;
+
+  // GetDecl may invalidate iterators into SpecializationsLookups, so copy the
+  // infos out before loading any decl.
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 8> Infos =
+      It->second.Table.findAll();
+
+  // Every specialization is about to be loaded, so D's lookup table is no
+  // longer needed.
+  if (!OnlyPartial)
+    SpecializationsLookups.erase(It);
+
+  Deserializing LookupResults(this);
+  for (auto &Info : Infos)
+    if (!OnlyPartial || Info.IsPartial)
+      GetDecl(Info.ID);
+}
+
+void ASTReader::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  assert(D);
+
+  auto It = SpecializationsLookups.find(D);
+  if (It == SpecializationsLookups.end())
+    return;
+
+  Deserializing LookupResults(this);
+  auto HashValue = TemplateArgumentList::ComputeODRHash(TemplateArgs);
+
+  // GetDecl may invalidate iterators into SpecializationsLookups; copy first.
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 8> Infos =
+      It->second.Table.find(HashValue);
+
+  for (auto &Info : Infos)
+    GetDecl(Info.ID);
+}
+
 void ASTReader::FindExternalLexicalDecls(
     const DeclContext *DC, llvm::function_ref<bool(Decl::Kind)> IsKindWeWant,
     SmallVectorImpl<Decl *> &Decls) {
@@ -8216,6 +8350,13 @@ ASTReader::getLoadedLookupTables(DeclContext *Primary) const {
   return I == Lookups.end() ? nullptr : &I->second;
 }
 
+serialization::reader::LazySpecializationInfoLookupTable *
+ASTReader::getLoadedSpecializationsLookupTables(const Decl *D) {
+  assert(D->isCanonicalDecl());
+  auto I = SpecializationsLookups.find(D);
+  return I == SpecializationsLookups.end() ? nullptr : &I->second;
+}
+
 /// Under non-PCH compilation the consumer receives the objc methods
 /// before receiving the implementation, and codegen depends on this.
 /// We simulate this by deserializing and passing to consumer the methods of the
diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp
index 6816bbcd45dcbe..87c491a977ffeb 100644
--- a/clang/lib/Serialization/ASTReaderDecl.cpp
+++ b/clang/lib/Serialization/ASTReaderDecl.cpp
@@ -342,6 +342,9 @@ class ASTDeclReader : public DeclVisitor<ASTDeclReader, void> {
   static void markIncompleteDeclChainImpl(Redeclarable<DeclT> *D);
   static void markIncompleteDeclChainImpl(...);
 
+  void ReadSpecializations(ModuleFile &M, Decl *D,
+                           llvm::BitstreamCursor &DeclsCursor);
+
   void ReadFunctionDefinition(FunctionDecl *FD);
   void Visit(Decl *D);
 
@@ -2430,6 +2433,14 @@ void ASTDeclReader::VisitImplicitConceptSpecializationDecl(
 void ASTDeclReader::VisitRequiresExprBodyDecl(RequiresExprBodyDecl *D) {
 }
 
+void ASTDeclReader::ReadSpecializations(ModuleFile &M, Decl *D,
+                                        llvm::BitstreamCursor &DeclsCursor) {
+  uint64_t Offset = ReadLocalOffset();
+  bool Failed = Reader.ReadSpecializations(M, DeclsCursor, Offset, D);
+  (void)Failed;
+  assert(!Failed);
+}
+
 RedeclarableResult
 ASTDeclReader::VisitRedeclarableTemplateDecl(RedeclarableTemplateDecl *D) {
   RedeclarableResult Redecl = VisitRedeclarable(D);
@@ -2471,6 +2482,7 @@ void ASTDeclReader::VisitClassTemplateDecl(ClassTemplateDecl *D) {
     SmallVector<LazySpecializationInfo, 32> SpecIDs;
     readDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
   }
 
   if (D->getTemplatedDecl()->TemplateOrInstantiation) {
@@ -2499,6 +2511,7 @@ void ASTDeclReader::VisitVarTemplateDecl(VarTemplateDecl *D) {
     SmallVector<LazySpecializationInfo, 32> SpecIDs;
     readDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
   }
 }
 
@@ -2600,6 +2613,7 @@ void ASTDeclReader::VisitFunctionTemplateDecl(FunctionTemplateDecl *D) {
     SmallVector<LazySpecializationInfo, 32> SpecIDs;
     readDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+    ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
   }
 }
 
@@ -3889,6 +3903,7 @@ Decl *ASTReader::ReadDeclRecord(GlobalDeclID ID) {
   switch ((DeclCode)MaybeDeclCode.get()) {
   case DECL_CONTEXT_LEXICAL:
   case DECL_CONTEXT_VISIBLE:
+  case DECL_SPECIALIZATIONS:
     llvm_unreachable("Record cannot be de-serialized with readDeclRecord");
   case DECL_TYPEDEF:
     D = TypedefDecl::CreateDeserialized(Context, ID);
@@ -4352,12 +4367,14 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
   assert((PendingLazySpecializationIDs.empty() || isa<ClassTemplateDecl>(D) ||
           isa<FunctionTemplateDecl, VarTemplateDecl>(D)) &&
          "Must not have pending specializations");
+  /* NOTE(review): disabled by this patch; delete it instead of commenting out.
   if (auto *CTD = dyn_cast<ClassTemplateDecl>(D))
     ASTDeclReader::AddLazySpecializations(CTD, PendingLazySpecializationIDs);
   else if (auto *FTD = dyn_cast<FunctionTemplateDecl>(D))
     ASTDeclReader::AddLazySpecializations(FTD, PendingLazySpecializationIDs);
   else if (auto *VTD = dyn_cast<VarTemplateDecl>(D))
     ASTDeclReader::AddLazySpecializations(VTD, PendingLazySpecializationIDs);
+  */
   PendingLazySpecializationIDs.clear();
 
   // Load the pending visible updates for this decl context, if it has any.
@@ -4383,6 +4400,16 @@ void ASTReader::loadDeclUpdateRecords(PendingUpdateRecord &Record) {
       FunctionToLambdasMap.erase(IT);
     }
   }
+
+  // Load the pending specializations update for this decl, if it has any.
+  if (auto I = PendingSpecializationsUpdates.find(ID);
+      I != PendingSpecializationsUpdates.end()) {
+    auto SpecializationUpdates = std::move(I->second);
+    PendingSpecializationsUpdates.erase(I);
+
+    for (const auto &Update : SpecializationUpdates)
+      AddSpecializations(D, Update.Data, *Update.Mod);
+  }
 }
 
 void ASTReader::loadPendingDeclChain(Decl *FirstLocal, uint64_t LocalOffset) {
diff --git a/clang/lib/Serialization/ASTReaderInternals.h b/clang/lib/Serialization/ASTReaderInternals.h
index 4f7e6f4b2741b7..b921d8d174c3a2 100644
--- a/clang/lib/Serialization/ASTReaderInternals.h
+++ b/clang/lib/Serialization/ASTReaderInternals.h
@@ -119,6 +119,102 @@ struct DeclContextLookupTable {
   MultiOnDiskHashTable<ASTDeclContextNameLookupTrait> Table;
 };
 
+struct LazySpecializationInfo {
+  // The Decl ID for the specialization.
+  GlobalDeclID ID;
+  // Whether or not this specialization is partial.
+  bool IsPartial;
+
+  bool operator==(const LazySpecializationInfo &Other) const {
+    assert(ID != Other.ID || IsPartial == Other.IsPartial);
+    return ID == Other.ID;
+  }
+
+  // The size of one serialized entry in the OnDiskHashTable. Don't use
+  // sizeof(), which also counts the alignment padding after IsPartial.
+  static constexpr unsigned Length = sizeof(DeclID) + sizeof(IsPartial);
+};
+
+/// Trait for the on-disk table from template-arg ODR hashes to specializations.
+class LazySpecializationInfoLookupTrait {
+  ASTReader &Reader;
+  ModuleFile &F;
+
+public:
+  // Maximum number of lookup tables we allow before condensing the tables.
+  static const int MaxTables = 4;
+
+  /// The lookup result is a list of LazySpecializationInfo (ID + IsPartial).
+  using data_type = SmallVector<LazySpecializationInfo, 4>;
+
+  struct data_type_builder {
+    data_type &Data;
+    llvm::DenseSet<LazySpecializationInfo> Found;
+
+    data_type_builder(data_type &D) : Data(D) {}
+
+    void insert(LazySpecializationInfo Info) {
+      // Just use a linear scan unless we have more than a few IDs.
+      if (Found.empty() && !Data.empty()) {
+        if (Data.size() <= 4) {
+          for (auto I : Data)
+            if (I == Info)
+              return;
+          Data.push_back(Info);
+          return;
+        }
+
+        // Switch to tracking found IDs in the set.
+        Found.insert(Data.begin(), Data.end());
+      }
+
+      if (Found.insert(Info).second)
+        Data.push_back(Info);
+    }
+  };
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+  using file_type = ModuleFile *;
+
+  using external_key_type = unsigned;
+  using internal_key_type = unsigned;
+
+  explicit LazySpecializationInfoLookupTrait(ASTReader &Reader, ModuleFile &F)
+      : Reader(Reader), F(F) {}
+
+  static bool EqualKey(const internal_key_type &a, const internal_key_type &b) {
+    return a == b;
+  }
+
+  static hash_value_type ComputeHash(const internal_key_type &Key) {
+    return Key;
+  }
+
+  static internal_key_type GetInternalKey(const external_key_type &Name) {
+    return Name;
+  }
+
+  static std::pair<unsigned, unsigned>
+  ReadKeyDataLength(const unsigned char *&d);
+
+  internal_key_type ReadKey(const unsigned char *d, unsigned);
+
+  void ReadDataInto(internal_key_type, const unsigned char *d, unsigned DataLen,
+                    data_type_builder &Val);
+
+  static void MergeDataInto(const data_type &From, data_type_builder &To) {
+    To.Data.reserve(To.Data.size() + From.size());
+    for (LazySpecializationInfo Info : From)
+      To.insert(Info);
+  }
+
+  file_type ReadFileRef(const unsigned char *&d);
+};
+
+struct LazySpecializationInfoLookupTable {
+  MultiOnDiskHashTable<LazySpecializationInfoLookupTrait> Table;
+};
+
 /// Base class for the trait describing the on-disk hash table for the
 /// identifiers in an AST file.
 ///
@@ -288,4 +384,32 @@ using HeaderFileInfoLookupTable =
 
 } // namespace clang
 
+namespace llvm {
+// The ID alone identifies a LazySpecializationInfo, so IsPartial takes no
+// part in hashing or equality.
+template <>
+struct DenseMapInfo<clang::serialization::reader::LazySpecializationInfo> {
+  using LazySpecializationInfo =
+      clang::serialization::reader::LazySpecializationInfo;
+  using Wrapped = DenseMapInfo<clang::serialization::DeclID>;
+
+  static inline LazySpecializationInfo getEmptyKey() {
+    return {(clang::GlobalDeclID)Wrapped::getEmptyKey(), false};
+  }
+
+  static inline LazySpecializationInfo getTombstoneKey() {
+    return {(clang::GlobalDeclID)Wrapped::getTombstoneKey(), false};
+  }
+
+  static unsigned getHashValue(const LazySpecializationInfo &Key) {
+    return Wrapped::getHashValue(Key.ID.getRawValue());
+  }
+
+  static bool isEqual(const LazySpecializationInfo &LHS,
+                      const LazySpecializationInfo &RHS) {
+    return LHS.ID == RHS.ID;
+  }
+};
+} // end namespace llvm
+
 #endif // LLVM_CLANG_LIB_SERIALIZATION_ASTREADERINTERNALS_H
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 455e8f0e749435..adfe9fc8b369d6 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -4106,6 +4106,155 @@ class ASTDeclContextNameLookupTrait {
 
 } // namespace
 
+namespace {
+class LazySpecializationInfoLookupTrait {
+  ASTWriter &Writer;
+  llvm::SmallVector<serialization::reader::LazySpecializationInfo, 64> Specs;
+
+public:
+  using key_type = unsigned;
+  using key_type_ref = key_type;
+
+  /// A start and end index into Specs, representing a sequence of decls.
+  using data_type = std::pair<unsigned, unsigned>;
+  using data_type_ref = const data_type &;
+
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+
+  explicit LazySpecializationInfoLookupTrait(ASTWriter &Writer)
+      : Writer(Writer) {}
+
+  template <typename Col> data_type getData(Col &&C) {
+    unsigned Start = Specs.size();
+    for (auto *D : C) {
+      bool IsPartial = isa<ClassTemplatePartialSpecializationDecl,
+                           VarTemplatePartialSpecializationDecl>(D);
+      NamedDecl *ND = getDeclForLocalLookup(Writer.getLangOpts(),
+                                            const_cast<NamedDecl *>(D));
+      Specs.push_back({GlobalDeclID(Writer.GetDeclRef(ND).getRawValue()),
+                       IsPartial});
+    }
+    return std::make_pair(Start, Specs.size());
+  }
+
+  data_type ImportData(
+      const reader::LazySpecializationInfoLookupTrait::data_type &FromReader) {
+    unsigned Start = Specs.size();
+    for (auto ID : FromReader)
+      Specs.push_back(ID);
+    return std::make_pair(Start, Specs.size());
+  }
+
+  static bool EqualKey(key_type_ref a, key_type_ref b) { return a == b; }
+
+  hash_value_type ComputeHash(key_type Name) { return Name; }
+
+  void EmitFileRef(raw_ostream &Out, ModuleFile *F) const {
+    assert(Writer.hasChain() &&
+           "have reference to loaded module file but no chain?");
+
+    using namespace llvm::support;
+    endian::write<uint32_t>(Out, Writer.getChain()->getModuleFileID(F),
+                            llvm::endianness::little);
+  }
+
+  std::pair<unsigned, unsigned> EmitKeyDataLength(raw_ostream &Out,
+                                                  key_type HashValue,
+                                                  data_type_ref Lookup) {
+    // The key is the 4-byte ODR hash written by EmitKey.
+    unsigned KeyLen = 4;
+    unsigned DataLen = serialization::reader::LazySpecializationInfo::Length *
+                       (Lookup.second - Lookup.first);
+
+    return emitULEBKeyDataLength(KeyLen, DataLen, Out);
+  }
+
+  void EmitKey(raw_ostream &Out, key_type HashValue, unsigned) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    LE.write<uint32_t>(HashValue);
+  }
+
+  void EmitData(raw_ostream &Out, key_type_ref, data_type Lookup,
+                unsigned DataLen) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    uint64_t Start = Out.tell();
+    (void)Start;
+    for (unsigned I = Lookup.first, N = Lookup.second; I != N; ++I) {
+      LE.write<DeclID>(Specs[I].ID.getRawValue());
+      LE.write<bool>(Specs[I].IsPartial);
+    }
+    assert(Out.tell() - Start == DataLen && "Data length is wrong");
+  }
+};
+
+unsigned CalculateODRHashForSpecs(const Decl *Spec) {
+  ArrayRef<TemplateArgument> Args;
+  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(Spec))
+    Args = CTSD->getTemplateArgs().asArray();
+  else if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(Spec))
+    Args = VTSD->getTemplateArgs().asArray();
+  else if (auto *FD = dyn_cast<FunctionDecl>(Spec))
+    Args = FD->getTemplateSpecializationArgs()->asArray();
+  else
+    llvm_unreachable("New Specialization Kind?");
+
+  return TemplateArgumentList::ComputeODRHash(Args);
+}
+} // namespace
+
+void ASTWriter::GenerateSpecializationInfoLookupTable(
+    const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations,
+    llvm::SmallVectorImpl<char> &LookupTable) {
+  assert(D->isFirstDecl());
+
+  // Create the on-disk hash table representation.
+  MultiOnDiskHashTableGenerator<reader::LazySpecializationInfoLookupTrait,
+                                LazySpecializationInfoLookupTrait>
+      Generator;
+  LazySpecializationInfoLookupTrait Trait(*this);
+
+  llvm::DenseMap<unsigned, llvm::SmallVector<const NamedDecl *, 4>>
+      SpecializationMaps;
+
+  // Bucket the specializations by the ODR hash of their template arguments;
+  // the hash is the on-disk table key, and the reader loads every decl in a
+  // bucket whose hash matches.
+  // NOTE(review): DenseMap<unsigned> reserves ~0U and ~0U - 1 as its
+  // empty/tombstone keys; confirm an ODR hash can never take those values.
+  for (auto *Specialization : Specializations) {
+    unsigned HashedValue = CalculateODRHashForSpecs(Specialization);
+
+    // operator[] default-constructs an empty bucket for a new hash, so each
+    // specialization costs a single map lookup.
+    SpecializationMaps[HashedValue].push_back(cast<NamedDecl>(Specialization));
+  }
+
+  for (const auto &Iter : SpecializationMaps)
+    Generator.insert(Iter.first, Trait.getData(Iter.second), Trait);
+
+  auto *Lookups =
+      Chain ? Chain->getLoadedSpecializationsLookupTables(D) : nullptr;
+  Generator.emit(LookupTable, Trait, Lookups ? &Lookups->Table : nullptr);
+}
+
+uint64_t ASTWriter::WriteSpecializationInfoLookupTable(
+    const NamedDecl *D, llvm::SmallVectorImpl<const Decl *> &Specializations) {
+
+  llvm::SmallString<4096> LookupTable;
+  GenerateSpecializationInfoLookupTable(D, Specializations, LookupTable);
+
+  uint64_t Offset = Stream.GetCurrentBitNo();
+  RecordData::value_type Record[] = {DECL_SPECIALIZATIONS};
+  Stream.EmitRecordWithBlob(DeclSpecializationsAbbrev, Record, LookupTable);
+
+  return Offset;
+}
+
 bool ASTWriter::isLookupResultExternal(StoredDeclsList &Result,
                                        DeclContext *DC) {
   return Result.hasExternalDecls() &&
@@ -5650,7 +5799,7 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
   // Keep writing types, declarations, and declaration update records
   // until we've emitted all of them.
   RecordData DeclUpdatesOffsetsRecord;
-  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/5);
+  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/6);
   DeclTypesBlockStartOffset = Stream.GetCurrentBitNo();
   WriteTypeAbbrevs();
   WriteDeclAbbrevs();
@@ -5724,6 +5873,9 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
                       FunctionToLambdaMapAbbrev);
   }
 
+  if (!SpecializationsUpdates.empty())
+    WriteSpecializationsUpdates();
+
   const TranslationUnitDecl *TU = Context.getTranslationUnitDecl();
   // Create a lexical update block containing all of the declarations in the
   // translation unit that do not come from other AST files.
@@ -5767,6 +5919,26 @@ void ASTWriter::WriteDeclAndTypes(ASTContext &Context) {
     WriteDeclContextVisibleUpdate(Context, DC);
 }
 
+/// Emit one UPDATE_SPECIALIZATION record (template DeclID plus lookup-table
+/// blob) for every template that has entries in SpecializationsUpdates.
+void ASTWriter::WriteSpecializationsUpdates() {
+  auto Abv = std::make_shared<llvm::BitCodeAbbrev>();
+  Abv->Add(llvm::BitCodeAbbrevOp(UPDATE_SPECIALIZATION));
+  Abv->Add(llvm::BitCodeAbbrevOp(llvm::BitCodeAbbrevOp::VBR, 6));
+  Abv->Add(llvm::BitCodeAbbrevOp(llvm::BitCodeAbbrevOp::Blob));
+  auto UpdateSpecializationAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+  for (auto &[D, Specs] : SpecializationsUpdates) {
+    llvm::SmallString<4096> LookupTable;
+    GenerateSpecializationInfoLookupTable(D, Specs, LookupTable);
+
+    // Fields: record code, then the template's DeclID; the table is the blob.
+    RecordData::value_type Record[] = {UPDATE_SPECIALIZATION,
+                                       getDeclID(D).getRawValue()};
+    Stream.EmitRecordWithBlob(UpdateSpecializationAbbrev, Record, LookupTable);
+  }
+}
+
 void ASTWriter::WriteDeclUpdatesBlocks(ASTContext &Context,
                                        RecordDataImpl &OffsetsRecord) {
   if (DeclUpdates.empty())
diff --git a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp
index 71a9bd1aefade6..14c2f23700e8c3 100644
--- a/clang/lib/Serialization/ASTWriterDecl.cpp
+++ b/clang/lib/Serialization/ASTWriterDecl.cpp
@@ -207,15 +207,17 @@ namespace clang {
     /// file that provides a declaration of D. We store the DeclId and an
     /// ODRHash of the template arguments of D which should provide enough
     /// information to load D only if the template instantiator needs it.
-    void AddFirstSpecializationDeclFromEachModule(const Decl *D,
-                                                  bool IncludeLocal) {
-      assert(isa<ClassTemplateSpecializationDecl>(D) ||
-             isa<VarTemplateSpecializationDecl>(D) ||
-             isa<FunctionDecl>(D) && "Must not be called with other decls");
+    void AddFirstSpecializationDeclFromEachModule(
+        const Decl *D, llvm::SmallVectorImpl<const Decl *> &SpecsInMap) {
+      assert((isa<ClassTemplateSpecializationDecl>(D) ||
+              isa<VarTemplateSpecializationDecl>(D) || isa<FunctionDecl>(D)) &&
+             "Must not be called with other decls");
       llvm::MapVector<ModuleFile *, const Decl *> Firsts;
-      CollectFirstDeclFromEachModule(D, IncludeLocal, Firsts);
+      CollectFirstDeclFromEachModule(D, /*IncludeLocal*/ true, Firsts);
 
       for (const auto &F : Firsts) {
+        SpecsInMap.push_back(F.second);
+
         Record.AddDeclRef(F.second);
         ArrayRef<TemplateArgument> Args;
         if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
@@ -282,10 +284,15 @@ namespace clang {
       for (auto &Entry : getPartialSpecializations(Common))
         Specs.push_back(getSpecializationDecl(Entry));
 
+      llvm::SmallVector<const Decl *, 16> SpecsInOnDiskMap = Specs;
+
       for (auto *D : Specs) {
         assert(D->isCanonicalDecl() && "non-canonical decl in set");
-        AddFirstSpecializationDeclFromEachModule(D, /*IncludeLocal*/ true);
+        AddFirstSpecializationDeclFromEachModule(D, SpecsInOnDiskMap);
       }
+
+      // We don't need to insert LazySpecializations to SpecsInOnDiskMap,
+      // since we'll handle that in GenerateSpecializationInfoLookupTable.
       for (auto &SpecInfo : LazySpecializations) {
         Record.push_back(SpecInfo.DeclID.getRawValue());
         Record.push_back(SpecInfo.ODRHash);
@@ -298,6 +305,9 @@ namespace clang {
       assert((Record.size() - I - 1) % 3 == 0 &&
              "Must be divisible by LazySpecializationInfo count!");
       Record[I] = (Record.size() - I - 1) / 3;
+
+      Record.AddOffset(
+          Writer.WriteSpecializationInfoLookupTable(D, SpecsInOnDiskMap));
     }
 
     /// Ensure that this template specialization is associated with the specified
@@ -320,6 +330,9 @@ namespace clang {
 
       Writer.DeclUpdates[Template].push_back(ASTWriter::DeclUpdate(
           UPD_CXX_ADDED_TEMPLATE_SPECIALIZATION, Specialization));
+
+      Writer.SpecializationsUpdates[cast<NamedDecl>(Template)].push_back(
+          cast<NamedDecl>(Specialization));
     }
   };
 }
@@ -2824,6 +2837,11 @@ void ASTWriter::WriteDeclAbbrevs() {
   Abv->Add(BitCodeAbbrevOp(serialization::DECL_CONTEXT_VISIBLE));
   Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
   DeclContextVisibleLookupAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+  Abv = std::make_shared<BitCodeAbbrev>();
+  Abv->Add(BitCodeAbbrevOp(serialization::DECL_SPECIALIZATIONS));
+  Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
+  DeclSpecializationsAbbrev = Stream.EmitAbbrev(std::move(Abv));
 }
 
 /// isRequiredDecl - Check if this is a "required" Decl, which must be seen by
diff --git a/clang/unittests/Serialization/CMakeLists.txt b/clang/unittests/Serialization/CMakeLists.txt
index e7eebd0cb98239..e7005b5d511eba 100644
--- a/clang/unittests/Serialization/CMakeLists.txt
+++ b/clang/unittests/Serialization/CMakeLists.txt
@@ -11,6 +11,7 @@ add_clang_unittest(SerializationTests
   ModuleCacheTest.cpp
   NoCommentsTest.cpp
   PreambleInNamedModulesTest.cpp
+  LoadSpecLazilyTest.cpp
   SourceLocationEncodingTest.cpp
   VarDeclConstantInitTest.cpp
   )
diff --git a/clang/unittests/Serialization/LoadSpecLazilyTest.cpp b/clang/unittests/Serialization/LoadSpecLazilyTest.cpp
new file mode 100644
index 00000000000000..76e3ccae3d3c3d
--- /dev/null
+++ b/clang/unittests/Serialization/LoadSpecLazilyTest.cpp
@@ -0,0 +1,260 @@
+//== unittests/Serialization/LoadSpecLazilyTest.cpp ------------------========//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Frontend/FrontendActions.h"
+#include "clang/Parse/ParseAST.h"
+#include "clang/Serialization/ASTDeserializationListener.h"
+#include "clang/Tooling/Tooling.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace clang;
+using namespace clang::tooling;
+
+namespace {
+
+/// Fixture that owns a scratch directory and precompiles module interface
+/// units into it so that tests can import the resulting BMIs.
+class LoadSpecLazilyTest : public ::testing::Test {
+  void SetUp() override {
+    ASSERT_FALSE(
+        sys::fs::createUniqueDirectory("load-spec-lazily-test", TestDir));
+  }
+
+  void TearDown() override { sys::fs::remove_directories(TestDir); }
+
+public:
+  SmallString<256> TestDir;
+
+  /// Write \p Contents to \p Path, which must be relative to TestDir.
+  void addFile(StringRef Path, StringRef Contents) {
+    ASSERT_FALSE(sys::path::is_absolute(Path));
+
+    SmallString<256> AbsPath(TestDir);
+    sys::path::append(AbsPath, Path);
+
+    ASSERT_FALSE(
+        sys::fs::create_directories(llvm::sys::path::parent_path(AbsPath)));
+
+    std::error_code EC;
+    llvm::raw_fd_ostream OS(AbsPath, EC);
+    ASSERT_FALSE(EC);
+    OS << Contents;
+  }
+
+  /// Write and precompile <ModuleName>.cppm; returns the path of the BMI.
+  std::string GenerateModuleInterface(StringRef ModuleName,
+                                      StringRef Contents) {
+    std::string FileName = llvm::Twine(ModuleName + ".cppm").str();
+    addFile(FileName, Contents);
+
+    IntrusiveRefCntPtr<DiagnosticsEngine> Diags =
+        CompilerInstance::createDiagnostics(new DiagnosticOptions());
+    CreateInvocationOptions CIOpts;
+    CIOpts.Diags = Diags;
+    CIOpts.VFS = llvm::vfs::createPhysicalFileSystem();
+
+    std::string CacheBMIPath =
+        llvm::Twine(TestDir + "/" + ModuleName + ".pcm").str();
+    std::string PrebuiltModulePath =
+        "-fprebuilt-module-path=" + TestDir.str().str();
+    const char *Args[] = {"clang++",
+                          "-std=c++20",
+                          "--precompile",
+                          PrebuiltModulePath.c_str(),
+                          "-working-directory",
+                          TestDir.c_str(),
+                          "-I",
+                          TestDir.c_str(),
+                          FileName.c_str(),
+                          "-o",
+                          CacheBMIPath.c_str()};
+    std::shared_ptr<CompilerInvocation> Invocation =
+        createInvocation(Args, CIOpts);
+    EXPECT_TRUE(Invocation);
+    // The failure is already recorded; don't dereference a null invocation.
+    if (!Invocation)
+      return CacheBMIPath;
+
+    CompilerInstance Instance;
+    Instance.setDiagnostics(Diags.get());
+    Instance.setInvocation(Invocation);
+    Instance.getFrontendOpts().OutputFile = CacheBMIPath;
+    GenerateModuleInterfaceAction Action;
+    EXPECT_TRUE(Instance.ExecuteAction(Action));
+    EXPECT_FALSE(Diags->hasErrorOccurred());
+
+    return CacheBMIPath;
+  }
+};
+
+enum class CheckingMode { Forbidden, Required };
+
+/// Watches deserialized decls whose name contains SpecifiedName. In
+/// Forbidden mode each such decl is a failure; in Required mode at least
+/// one must have been read by the time the listener is destroyed.
+class DeclsReaderListener : public ASTDeserializationListener {
+  StringRef SpecifiedName;
+  CheckingMode Mode;
+
+  bool HasReadSpecifiedName = false;
+
+public:
+  void DeclRead(GlobalDeclID ID, const Decl *D) override {
+    auto *ND = dyn_cast<NamedDecl>(D);
+    // NamedDecl::getName() asserts on non-identifier names (constructors,
+    // operators, ...), so only identifier-named decls are checked.
+    if (!ND || !ND->getDeclName().isIdentifier())
+      return;
+
+    bool Matches = ND->getName().contains(SpecifiedName);
+    HasReadSpecifiedName |= Matches;
+    if (Mode == CheckingMode::Forbidden)
+      EXPECT_FALSE(Matches) << "deserialized " << ND->getNameAsString();
+  }
+
+  DeclsReaderListener(StringRef SpecifiedName, CheckingMode Mode)
+      : SpecifiedName(SpecifiedName), Mode(Mode) {}
+
+  ~DeclsReaderListener() override {
+    if (Mode == CheckingMode::Required)
+      EXPECT_TRUE(HasReadSpecifiedName)
+          << "no decl named like " << SpecifiedName.str() << " was read";
+  }
+};
+
+/// Consumer whose only role is to expose a DeclsReaderListener.
+class LoadSpecLazilyConsumer : public ASTConsumer {
+  DeclsReaderListener Listener;
+
+public:
+  LoadSpecLazilyConsumer(StringRef SpecifiedName, CheckingMode Mode)
+      : Listener(SpecifiedName, Mode) {}
+
+  ASTDeserializationListener *GetASTDeserializationListener() override {
+    return &Listener;
+  }
+};
+
+/// Frontend action that installs a LoadSpecLazilyConsumer.
+class CheckLoadSpecLazilyAction : public ASTFrontendAction {
+  StringRef SpecifiedName;
+  CheckingMode Mode;
+
+public:
+  std::unique_ptr<ASTConsumer>
+  CreateASTConsumer(CompilerInstance &CI, StringRef /*Unused*/) override {
+    return std::make_unique<LoadSpecLazilyConsumer>(SpecifiedName, Mode);
+  }
+
+  CheckLoadSpecLazilyAction(StringRef SpecifiedName, CheckingMode Mode)
+      : SpecifiedName(SpecifiedName), Mode(Mode) {}
+};
+
+/// Compile \p Code as test.cpp against the BMIs prebuilt in \p TestDir,
+/// checking deserialized decls against \p SpecifiedName in \p Mode.
+bool runCheckLoadSpecLazily(StringRef TestDir, StringRef Code,
+                            StringRef SpecifiedName, CheckingMode Mode) {
+  return runToolOnCodeWithArgs(
+      std::make_unique<CheckLoadSpecLazilyAction>(SpecifiedName, Mode), Code,
+      {"-std=c++20", "-fprebuilt-module-path=" + TestDir.str(), "-I",
+       TestDir.str()},
+      "test.cpp");
+}
+
+TEST_F(LoadSpecLazilyTest, BasicTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+export class ShouldNotBeLoaded {};
+export class Temp {
+   A<ShouldNotBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import M;
+A<int> a;
+  )cpp";
+  // Using A<int> must not deserialize anything named ShouldNotBeLoaded.
+  EXPECT_TRUE(runCheckLoadSpecLazily(TestDir, test_file_contents,
+                                     "ShouldNotBeLoaded",
+                                     CheckingMode::Forbidden));
+}
+
+TEST_F(LoadSpecLazilyTest, ChainedTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+  )cpp");
+
+  GenerateModuleInterface("N", R"cpp(
+export module N;
+export import M;
+export class ShouldNotBeLoaded {};
+export class Temp {
+   A<ShouldNotBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import N;
+A<int> a;
+  )cpp";
+  // As BasicTest, but ShouldNotBeLoaded lives in N, which re-exports M.
+  EXPECT_TRUE(runCheckLoadSpecLazily(TestDir, test_file_contents,
+                                     "ShouldNotBeLoaded",
+                                     CheckingMode::Forbidden));
+}
+
+/// Test that we don't crash if the lazy specialization lookup table is
+/// invalidated while specializations are being loaded.
+TEST_F(LoadSpecLazilyTest, ChainedTest2) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+
+export class B {};
+
+export class C {
+  A<B> D;
+};
+  )cpp");
+
+  GenerateModuleInterface("N", R"cpp(
+export module N;
+export import M;
+export class MayBeLoaded {};
+
+export class Temp {
+   A<MayBeLoaded> AS;
+};
+
+export class ExportedClass {};
+
+export template<> class A<ExportedClass> {
+   A<MayBeLoaded> AS;
+   A<B>           AB;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import N;
+Temp T;
+A<ExportedClass> a;
+  )cpp";
+  EXPECT_TRUE(runCheckLoadSpecLazily(TestDir, test_file_contents,
+                                     "MayBeLoaded", CheckingMode::Required));
+}
+
+} // namespace



More information about the llvm-branch-commits mailing list