[clang] [Serialization] Load Specializations Lazily (1/2) (PR #76774)

Chuanqi Xu via cfe-commits cfe-commits at lists.llvm.org
Mon Jan 8 22:57:00 PST 2024


https://github.com/ChuanqiXu9 updated https://github.com/llvm/llvm-project/pull/76774

>From 50fd47f2bfda527807f8cc5e46425050246868aa Mon Sep 17 00:00:00 2001
From: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
Date: Wed, 3 Jan 2024 11:33:17 +0800
Subject: [PATCH] Load Specializations Lazily

---
 clang/include/clang/AST/DeclTemplate.h        |  35 ++--
 clang/include/clang/AST/ExternalASTSource.h   |   6 +
 clang/include/clang/AST/ODRHash.h             |   3 +
 .../clang/Sema/MultiplexExternalSemaSource.h  |   6 +
 .../include/clang/Serialization/ASTBitCodes.h |   3 +
 clang/include/clang/Serialization/ASTReader.h |  28 +++
 clang/include/clang/Serialization/ASTWriter.h |   6 +
 clang/lib/AST/DeclTemplate.cpp                |  67 +++++---
 clang/lib/AST/ExternalASTSource.cpp           |   5 +
 clang/lib/AST/ODRHash.cpp                     |   2 +
 .../lib/Sema/MultiplexExternalSemaSource.cpp  |   6 +
 clang/lib/Serialization/ASTReader.cpp         | 125 +++++++++++++-
 clang/lib/Serialization/ASTReaderDecl.cpp     |  35 +++-
 clang/lib/Serialization/ASTReaderInternals.h  |  80 +++++++++
 clang/lib/Serialization/ASTWriter.cpp         | 144 +++++++++++++++-
 clang/lib/Serialization/ASTWriterDecl.cpp     |  78 ++++++---
 clang/test/Modules/odr_hash.cpp               |   4 +-
 clang/unittests/Serialization/CMakeLists.txt  |   1 +
 .../Serialization/LoadSpecLazily.cpp          | 159 ++++++++++++++++++
 19 files changed, 728 insertions(+), 65 deletions(-)
 create mode 100644 clang/unittests/Serialization/LoadSpecLazily.cpp

diff --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h
index 832ad2de6b08a8..4699dd17bc182c 100644
--- a/clang/include/clang/AST/DeclTemplate.h
+++ b/clang/include/clang/AST/DeclTemplate.h
@@ -525,8 +525,11 @@ class FunctionTemplateSpecializationInfo final
     return Function.getInt();
   }
 
+  void loadExternalRedecls();
+
 public:
   friend TrailingObjects;
+  friend class ASTReader;
 
   static FunctionTemplateSpecializationInfo *
   Create(ASTContext &C, FunctionDecl *FD, FunctionTemplateDecl *Template,
@@ -789,13 +792,15 @@ class RedeclarableTemplateDecl : public TemplateDecl,
     return SpecIterator<EntryType>(isEnd ? Specs.end() : Specs.begin());
   }
 
-  void loadLazySpecializationsImpl() const;
+  void loadExternalSpecializations() const;
 
   template <class EntryType, typename ...ProfileArguments>
   typename SpecEntryTraits<EntryType>::DeclType*
   findSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
                          void *&InsertPos, ProfileArguments &&...ProfileArgs);
 
+  void loadLazySpecializationsWithArgs(ArrayRef<TemplateArgument> TemplateArgs);
+
   template <class Derived, class EntryType>
   void addSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
                              EntryType *Entry, void *InsertPos);
@@ -814,9 +819,13 @@ class RedeclarableTemplateDecl : public TemplateDecl,
     /// If non-null, points to an array of specializations (including
     /// partial specializations) known only by their external declaration IDs.
     ///
+    /// These specializations need to be loaded all at once in
+    /// loadExternalSpecializations, either to complete the redecl chain or to
+    /// prepare for template resolution.
+    ///
     /// The first value in the array is the number of specializations/partial
     /// specializations that follow.
-    uint32_t *LazySpecializations = nullptr;
+    uint32_t *ExternalSpecializations = nullptr;
 
     /// The set of "injected" template arguments used within this
     /// template.
@@ -850,6 +859,8 @@ class RedeclarableTemplateDecl : public TemplateDecl,
   friend class ASTDeclWriter;
   friend class ASTReader;
   template <class decl_type> friend class RedeclarableTemplate;
+  friend class ClassTemplateSpecializationDecl;
+  friend class VarTemplateSpecializationDecl;
 
   /// Retrieves the canonical declaration of this template.
   RedeclarableTemplateDecl *getCanonicalDecl() override {
@@ -977,6 +988,7 @@ SpecEntryTraits<FunctionTemplateSpecializationInfo> {
 class FunctionTemplateDecl : public RedeclarableTemplateDecl {
 protected:
   friend class FunctionDecl;
+  friend class FunctionTemplateSpecializationInfo;
 
   /// Data that is common to all of the declarations of a given
   /// function template.
@@ -1012,13 +1024,13 @@ class FunctionTemplateDecl : public RedeclarableTemplateDecl {
   void addSpecialization(FunctionTemplateSpecializationInfo* Info,
                          void *InsertPos);
 
+  /// Load any lazily-loaded specializations from the external source.
+  void LoadLazySpecializations() const;
+
 public:
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
 
-  /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
-
   /// Get the underlying function declaration of the template.
   FunctionDecl *getTemplatedDecl() const {
     return static_cast<FunctionDecl *>(TemplatedDecl);
@@ -1839,6 +1851,8 @@ class ClassTemplateSpecializationDecl
   LLVM_PREFERRED_TYPE(TemplateSpecializationKind)
   unsigned SpecializationKind : 3;
 
+  void loadExternalRedecls();
+
 protected:
   ClassTemplateSpecializationDecl(ASTContext &Context, Kind DK, TagKind TK,
                                   DeclContext *DC, SourceLocation StartLoc,
@@ -1852,6 +1866,7 @@ class ClassTemplateSpecializationDecl
 public:
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
+  friend class ASTReader;
 
   static ClassTemplateSpecializationDecl *
   Create(ASTContext &Context, TagKind TK, DeclContext *DC,
@@ -2285,9 +2300,7 @@ class ClassTemplateDecl : public RedeclarableTemplateDecl {
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
   friend class TemplateDeclInstantiator;
-
-  /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
+  friend class ClassTemplateSpecializationDecl;
 
   /// Get the underlying class declarations of the template.
   CXXRecordDecl *getTemplatedDecl() const {
@@ -2651,6 +2664,8 @@ class VarTemplateSpecializationDecl : public VarDecl,
   LLVM_PREFERRED_TYPE(bool)
   unsigned IsCompleteDefinition : 1;
 
+  void loadExternalRedecls();
+
 protected:
   VarTemplateSpecializationDecl(Kind DK, ASTContext &Context, DeclContext *DC,
                                 SourceLocation StartLoc, SourceLocation IdLoc,
@@ -2664,6 +2679,7 @@ class VarTemplateSpecializationDecl : public VarDecl,
 public:
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
+  friend class ASTReader;
   friend class VarDecl;
 
   static VarTemplateSpecializationDecl *
@@ -3057,8 +3073,7 @@ class VarTemplateDecl : public RedeclarableTemplateDecl {
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
 
-  /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
+  friend class VarTemplatePartialSpecializationDecl;
 
   /// Get the underlying variable declarations of the template.
   VarDecl *getTemplatedDecl() const {
diff --git a/clang/include/clang/AST/ExternalASTSource.h b/clang/include/clang/AST/ExternalASTSource.h
index 8e573965b0a336..75ee6f827080d4 100644
--- a/clang/include/clang/AST/ExternalASTSource.h
+++ b/clang/include/clang/AST/ExternalASTSource.h
@@ -150,6 +150,12 @@ class ExternalASTSource : public RefCountedBase<ExternalASTSource> {
   virtual bool
   FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name);
 
+  /// Load all the external specializations for the Decl and the corresponding
+  /// template arguments.
+  virtual void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs);
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   ///
diff --git a/clang/include/clang/AST/ODRHash.h b/clang/include/clang/AST/ODRHash.h
index cedf644520fc32..ddd1bb0f095e75 100644
--- a/clang/include/clang/AST/ODRHash.h
+++ b/clang/include/clang/AST/ODRHash.h
@@ -101,6 +101,9 @@ class ODRHash {
   // Save booleans until the end to lower the size of data to process.
   void AddBoolean(bool value);
 
+  // Add integers to ID.
+  void AddInteger(unsigned Value);
+
   static bool isSubDeclToBeProcessed(const Decl *D, const DeclContext *Parent);
 
 private:
diff --git a/clang/include/clang/Sema/MultiplexExternalSemaSource.h b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
index 2bf91cb5212c5e..9345e96e8cb8c3 100644
--- a/clang/include/clang/Sema/MultiplexExternalSemaSource.h
+++ b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
@@ -97,6 +97,12 @@ class MultiplexExternalSemaSource : public ExternalSemaSource {
   bool FindExternalVisibleDeclsByName(const DeclContext *DC,
                                       DeclarationName Name) override;
 
+  /// Load all the external specializations for the Decl and the corresponding
+  /// template args.
+  virtual void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   void completeVisibleDeclsMap(const DeclContext *DC) override;
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index fdd64f2abbe937..23a279de96ab15 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1523,6 +1523,9 @@ enum DeclCode {
   /// An ImplicitConceptSpecializationDecl record.
   DECL_IMPLICIT_CONCEPT_SPECIALIZATION,
 
+  /// A record of a declaration's specializations.
+  DECL_SPECIALIZATIONS,
+
   DECL_LAST = DECL_IMPLICIT_CONCEPT_SPECIALIZATION
 };
 
diff --git a/clang/include/clang/Serialization/ASTReader.h b/clang/include/clang/Serialization/ASTReader.h
index 21d791f5cd89a2..293d6495d164ef 100644
--- a/clang/include/clang/Serialization/ASTReader.h
+++ b/clang/include/clang/Serialization/ASTReader.h
@@ -340,6 +340,9 @@ class ASTIdentifierLookupTrait;
 /// The on-disk hash table(s) used for DeclContext name lookup.
 struct DeclContextLookupTable;
 
+/// The on-disk hash table(s) used for specialization decls.
+struct SpecializationsLookupTable;
+
 } // namespace reader
 
 } // namespace serialization
@@ -599,6 +602,11 @@ class ASTReader
   llvm::DenseMap<const DeclContext *,
                  serialization::reader::DeclContextLookupTable> Lookups;
 
+  /// Map from decls to specialized decls.
+  llvm::DenseMap<const Decl *,
+                 serialization::reader::SpecializationsLookupTable>
+      SpecializationsLookups;
+
   // Updates for visible decls can occur for other contexts than just the
   // TU, and when we read those update records, the actual context may not
   // be available yet, so have this pending map using the ID as a key. It
@@ -640,6 +648,9 @@ class ASTReader
                                      llvm::BitstreamCursor &Cursor,
                                      uint64_t Offset, serialization::DeclID ID);
 
+  bool ReadSpecializations(ModuleFile &M, llvm::BitstreamCursor &Cursor,
+                           uint64_t Offset, Decl *D);
+
   /// A vector containing identifiers that have already been
   /// loaded.
   ///
@@ -1343,6 +1354,11 @@ class ASTReader
   const serialization::reader::DeclContextLookupTable *
   getLoadedLookupTables(DeclContext *Primary) const;
 
+  /// Get the loaded specializations lookup tables for \p D,
+  /// if any.
+  serialization::reader::SpecializationsLookupTable *
+  getLoadedSpecializationsLookupTables(const Decl *D);
+
 private:
   struct ImportedModule {
     ModuleFile *Mod;
@@ -1982,6 +1998,10 @@ class ASTReader
   bool FindExternalVisibleDeclsByName(const DeclContext *DC,
                                       DeclarationName Name) override;
 
+  void
+  LoadExternalSpecializations(const Decl *D,
+                              ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Read all of the declarations lexically stored in a
   /// declaration context.
   ///
@@ -2404,6 +2424,14 @@ class ASTReader
   bool isProcessingUpdateRecords() { return ProcessingUpdateRecords; }
 };
 
+/// Get a stable hash for template arguments across compiler invocations.
+/// This is used for loading corresponding specializations lazily according
+/// to the template arguments.
+/// Since loading additional specializations is harmless, different template
+/// arguments may map to the same hash value, as long as semantically
+/// identical template arguments always get the same hash value.
+unsigned GetTemplateArgsStableHash(ArrayRef<TemplateArgument> TemplateArgs);
+
 /// A simple helper class to unpack an integer to bits and consuming
 /// the bits in order.
 class BitsUnpacker {
diff --git a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h
index de69f99003d827..09806b87590766 100644
--- a/clang/include/clang/Serialization/ASTWriter.h
+++ b/clang/include/clang/Serialization/ASTWriter.h
@@ -527,6 +527,10 @@ class ASTWriter : public ASTDeserializationListener,
   bool isLookupResultExternal(StoredDeclsList &Result, DeclContext *DC);
   bool isLookupResultEntirelyExternal(StoredDeclsList &Result, DeclContext *DC);
 
+  uint64_t WriteSpecializationsLookupTable(
+      const NamedDecl *D,
+      llvm::SmallVectorImpl<const NamedDecl *> &Specializations);
+
   void GenerateNameLookupTable(const DeclContext *DC,
                                llvm::SmallVectorImpl<char> &LookupTable);
   uint64_t WriteDeclContextLexicalBlock(ASTContext &Context, DeclContext *DC);
@@ -564,6 +568,8 @@ class ASTWriter : public ASTDeserializationListener,
   unsigned DeclEnumAbbrev = 0;
   unsigned DeclObjCIvarAbbrev = 0;
   unsigned DeclCXXMethodAbbrev = 0;
+  unsigned DeclSpecializationsAbbrev = 0;
+
   unsigned DeclDependentNonTemplateCXXMethodAbbrev = 0;
   unsigned DeclTemplateCXXMethodAbbrev = 0;
   unsigned DeclMemberSpecializedCXXMethodAbbrev = 0;
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index 7d7556e670f951..16b396a5e785d7 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -331,14 +331,14 @@ RedeclarableTemplateDecl::CommonBase *RedeclarableTemplateDecl::getCommonPtr() c
   return Common;
 }
 
-void RedeclarableTemplateDecl::loadLazySpecializationsImpl() const {
+void RedeclarableTemplateDecl::loadExternalSpecializations() const {
   // Grab the most recent declaration to ensure we've loaded any lazy
   // redeclarations of this template.
   CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
-  if (CommonBasePtr->LazySpecializations) {
+  if (CommonBasePtr->ExternalSpecializations) {
     ASTContext &Context = getASTContext();
-    uint32_t *Specs = CommonBasePtr->LazySpecializations;
-    CommonBasePtr->LazySpecializations = nullptr;
+    uint32_t *Specs = CommonBasePtr->ExternalSpecializations;
+    CommonBasePtr->ExternalSpecializations = nullptr;
     for (uint32_t I = 0, N = *Specs++; I != N; ++I)
       (void)Context.getExternalSource()->GetExternalDecl(Specs[I]);
   }
@@ -358,6 +358,16 @@ RedeclarableTemplateDecl::findSpecializationImpl(
   return Entry ? SETraits::getDecl(Entry)->getMostRecentDecl() : nullptr;
 }
 
+void RedeclarableTemplateDecl::loadLazySpecializationsWithArgs(
+    ArrayRef<TemplateArgument> TemplateArgs) {
+  auto *ExternalSource = getASTContext().getExternalSource();
+  if (!ExternalSource)
+    return;
+
+  ExternalSource->LoadExternalSpecializations(this->getCanonicalDecl(),
+                                              TemplateArgs);
+}
+
 template<class Derived, class EntryType>
 void RedeclarableTemplateDecl::addSpecializationImpl(
     llvm::FoldingSetVector<EntryType> &Specializations, EntryType *Entry,
@@ -430,24 +440,23 @@ FunctionTemplateDecl::newCommon(ASTContext &C) const {
   return CommonPtr;
 }
 
-void FunctionTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
-}
-
 llvm::FoldingSetVector<FunctionTemplateSpecializationInfo> &
 FunctionTemplateDecl::getSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->Specializations;
 }
 
 FunctionDecl *
 FunctionTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                          void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void FunctionTemplateDecl::addSpecialization(
       FunctionTemplateSpecializationInfo *Info, void *InsertPos) {
+  using SETraits = SpecEntryTraits<FunctionTemplateSpecializationInfo>;
+  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(Info));
   addSpecializationImpl<FunctionTemplateDecl>(getSpecializations(), Info,
                                               InsertPos);
 }
@@ -508,19 +517,15 @@ ClassTemplateDecl *ClassTemplateDecl::CreateDeserialized(ASTContext &C,
                                        DeclarationName(), nullptr, nullptr);
 }
 
-void ClassTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
-}
-
 llvm::FoldingSetVector<ClassTemplateSpecializationDecl> &
 ClassTemplateDecl::getSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->Specializations;
 }
 
 llvm::FoldingSetVector<ClassTemplatePartialSpecializationDecl> &
 ClassTemplateDecl::getPartialSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->PartialSpecializations;
 }
 
@@ -534,11 +539,14 @@ ClassTemplateDecl::newCommon(ASTContext &C) const {
 ClassTemplateSpecializationDecl *
 ClassTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                       void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void ClassTemplateDecl::AddSpecialization(ClassTemplateSpecializationDecl *D,
                                           void *InsertPos) {
+  using SETraits = SpecEntryTraits<ClassTemplateSpecializationDecl>;
+  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(D));
   addSpecializationImpl<ClassTemplateDecl>(getSpecializations(), D, InsertPos);
 }
 
@@ -546,6 +554,7 @@ ClassTemplatePartialSpecializationDecl *
 ClassTemplateDecl::findPartialSpecialization(
     ArrayRef<TemplateArgument> Args,
     TemplateParameterList *TPL, void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getPartialSpecializations(), InsertPos, Args,
                                 TPL);
 }
@@ -900,6 +909,11 @@ FunctionTemplateSpecializationInfo *FunctionTemplateSpecializationInfo::Create(
       FD, Template, TSK, TemplateArgs, ArgsAsWritten, POI, MSInfo);
 }
 
+void FunctionTemplateSpecializationInfo::loadExternalRedecls() {
+  getTemplate()->loadExternalSpecializations();
+  getTemplate()->loadLazySpecializationsWithArgs(TemplateArguments->asArray());
+}
+
 //===----------------------------------------------------------------------===//
 // ClassTemplateSpecializationDecl Implementation
 //===----------------------------------------------------------------------===//
@@ -1024,6 +1038,12 @@ ClassTemplateSpecializationDecl::getSourceRange() const {
   }
 }
 
+void ClassTemplateSpecializationDecl::loadExternalRedecls() {
+  getSpecializedTemplate()->loadExternalSpecializations();
+  getSpecializedTemplate()->loadLazySpecializationsWithArgs(
+      getTemplateArgs().asArray());
+}
+
 //===----------------------------------------------------------------------===//
 // ConceptDecl Implementation
 //===----------------------------------------------------------------------===//
@@ -1226,19 +1246,15 @@ VarTemplateDecl *VarTemplateDecl::CreateDeserialized(ASTContext &C,
                                      DeclarationName(), nullptr, nullptr);
 }
 
-void VarTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
-}
-
 llvm::FoldingSetVector<VarTemplateSpecializationDecl> &
 VarTemplateDecl::getSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->Specializations;
 }
 
 llvm::FoldingSetVector<VarTemplatePartialSpecializationDecl> &
 VarTemplateDecl::getPartialSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->PartialSpecializations;
 }
 
@@ -1252,17 +1268,21 @@ VarTemplateDecl::newCommon(ASTContext &C) const {
 VarTemplateSpecializationDecl *
 VarTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                     void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void VarTemplateDecl::AddSpecialization(VarTemplateSpecializationDecl *D,
                                         void *InsertPos) {
+  using SETraits = SpecEntryTraits<VarTemplateSpecializationDecl>;
+  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(D));
   addSpecializationImpl<VarTemplateDecl>(getSpecializations(), D, InsertPos);
 }
 
 VarTemplatePartialSpecializationDecl *
 VarTemplateDecl::findPartialSpecialization(ArrayRef<TemplateArgument> Args,
      TemplateParameterList *TPL, void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getPartialSpecializations(), InsertPos, Args,
                                 TPL);
 }
@@ -1393,6 +1413,11 @@ SourceRange VarTemplateSpecializationDecl::getSourceRange() const {
   return VarDecl::getSourceRange();
 }
 
+void VarTemplateSpecializationDecl::loadExternalRedecls() {
+  getSpecializedTemplate()->loadExternalSpecializations();
+  getSpecializedTemplate()->loadLazySpecializationsWithArgs(
+      getTemplateArgs().asArray());
+}
 
 //===----------------------------------------------------------------------===//
 // VarTemplatePartialSpecializationDecl Implementation
diff --git a/clang/lib/AST/ExternalASTSource.cpp b/clang/lib/AST/ExternalASTSource.cpp
index 090ef02aa4224d..bba73fcbd8fa23 100644
--- a/clang/lib/AST/ExternalASTSource.cpp
+++ b/clang/lib/AST/ExternalASTSource.cpp
@@ -100,6 +100,11 @@ ExternalASTSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
   return false;
 }
 
+void ExternalASTSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  return;
+}
+
 void ExternalASTSource::completeVisibleDeclsMap(const DeclContext *DC) {}
 
 void ExternalASTSource::FindExternalLexicalDecls(
diff --git a/clang/lib/AST/ODRHash.cpp b/clang/lib/AST/ODRHash.cpp
index aea1a93ae1fa82..ace24eb4d29d85 100644
--- a/clang/lib/AST/ODRHash.cpp
+++ b/clang/lib/AST/ODRHash.cpp
@@ -1249,3 +1249,5 @@ void ODRHash::AddQualType(QualType T) {
 void ODRHash::AddBoolean(bool Value) {
   Bools.push_back(Value);
 }
+
+void ODRHash::AddInteger(unsigned Value) { ID.AddInteger(Value); }
diff --git a/clang/lib/Sema/MultiplexExternalSemaSource.cpp b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
index 058e22cb2b814e..b845383675ab36 100644
--- a/clang/lib/Sema/MultiplexExternalSemaSource.cpp
+++ b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
@@ -115,6 +115,12 @@ FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name) {
   return AnyDeclsFound;
 }
 
+void MultiplexExternalSemaSource::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  for (size_t i = 0; i < Sources.size(); ++i)
+    Sources[i]->LoadExternalSpecializations(D, TemplateArgs);
+}
+
 void MultiplexExternalSemaSource::completeVisibleDeclsMap(const DeclContext *DC){
   for(size_t i = 0; i < Sources.size(); ++i)
     Sources[i]->completeVisibleDeclsMap(DC);
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 9effd333daccdb..bcdd2dfc491d8f 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -1223,6 +1223,38 @@ void ASTDeclContextNameLookupTrait::ReadDataInto(internal_key_type,
   }
 }
 
+ModuleFile *SpecializationsLookupTrait::ReadFileRef(const unsigned char *&d) {
+  using namespace llvm::support;
+
+  uint32_t ModuleFileID =
+      endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+  return Reader.getLocalModuleFile(F, ModuleFileID);
+}
+
+SpecializationsLookupTrait::internal_key_type
+SpecializationsLookupTrait::ReadKey(const unsigned char *d, unsigned) {
+  using namespace llvm::support;
+  return endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+}
+
+std::pair<unsigned, unsigned>
+SpecializationsLookupTrait::ReadKeyDataLength(const unsigned char *&d) {
+  return readULEBKeyDataLength(d);
+}
+
+void SpecializationsLookupTrait::ReadDataInto(internal_key_type,
+                                              const unsigned char *d,
+                                              unsigned DataLen,
+                                              data_type_builder &Val) {
+  using namespace llvm::support;
+
+  for (unsigned NumDecls = DataLen / 4; NumDecls; --NumDecls) {
+    uint32_t LocalID =
+        endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+    Val.insert(Reader.getGlobalDeclID(F, LocalID));
+  }
+}
+
 bool ASTReader::ReadLexicalDeclContextStorage(ModuleFile &M,
                                               BitstreamCursor &Cursor,
                                               uint64_t Offset,
@@ -1312,6 +1344,44 @@ bool ASTReader::ReadVisibleDeclContextStorage(ModuleFile &M,
   return false;
 }
 
+bool ASTReader::ReadSpecializations(ModuleFile &M, BitstreamCursor &Cursor,
+                                    uint64_t Offset, Decl *D) {
+  assert(Offset != 0);
+
+  SavedStreamPosition SavedPosition(Cursor);
+  if (llvm::Error Err = Cursor.JumpToBit(Offset)) {
+    Error(std::move(Err));
+    return true;
+  }
+
+  RecordData Record;
+  StringRef Blob;
+  Expected<unsigned> MaybeCode = Cursor.ReadCode();
+  if (!MaybeCode) {
+    Error(MaybeCode.takeError());
+    return true;
+  }
+  unsigned Code = MaybeCode.get();
+
+  Expected<unsigned> MaybeRecCode = Cursor.readRecord(Code, Record, &Blob);
+  if (!MaybeRecCode) {
+    Error(MaybeRecCode.takeError());
+    return true;
+  }
+  unsigned RecCode = MaybeRecCode.get();
+  if (RecCode != DECL_SPECIALIZATIONS) {
+    Error("Expected decl specs block");
+    return true;
+  }
+
+  auto *Data = (const unsigned char *)Blob.data();
+  D = D->getCanonicalDecl();
+  SpecializationsLookups[D].Table.add(
+      &M, Data, reader::SpecializationsLookupTrait(*this, M));
+
+  return false;
+}
+
 void ASTReader::Error(StringRef Msg) const {
   Error(diag::err_fe_pch_malformed, Msg);
   if (PP.getLangOpts().Modules && !Diags.isDiagnosticInFlight() &&
@@ -7523,13 +7593,13 @@ void ASTReader::CompleteRedeclChain(const Decl *D) {
   }
 
   if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
-    CTSD->getSpecializedTemplate()->LoadLazySpecializations();
+    const_cast<ClassTemplateSpecializationDecl *>(CTSD)->loadExternalRedecls();
   if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(D))
-    VTSD->getSpecializedTemplate()->LoadLazySpecializations();
-  if (auto *FD = dyn_cast<FunctionDecl>(D)) {
-    if (auto *Template = FD->getPrimaryTemplate())
-      Template->LoadLazySpecializations();
-  }
+    const_cast<VarTemplateSpecializationDecl *>(VTSD)->loadExternalRedecls();
+  if (auto *FD = dyn_cast<FunctionDecl>(D))
+    if (auto *FDInfo = FD->getTemplateSpecializationInfo())
+      const_cast<FunctionTemplateSpecializationInfo *>(FDInfo)
+          ->loadExternalRedecls();
 }
 
 CXXCtorInitializer **
@@ -7958,6 +8028,42 @@ ASTReader::FindExternalVisibleDeclsByName(const DeclContext *DC,
   return !Decls.empty();
 }
 
+unsigned
+clang::GetTemplateArgsStableHash(ArrayRef<TemplateArgument> TemplateArgs) {
+  // FIXME: ODR hashing may not be the best mechanism to hash the template
+  // arguments. ODR hashing is (or perhaps, should be) about determining whether
+  // two things are spelled the same way and have the same meaning (as required
+  // by the C++ ODR), whereas what we want here is whether they have the same
+  // meaning regardless of spelling. Maybe we can get away with reusing ODR
+  // hashing anyway, on the basis that any canonical, non-dependent template
+  // argument should have the same (invented) spelling in every translation
+  // unit, but we are unsure that's true in all cases. There may still be cases
+  // where the canonical type includes some aspect of "whatever we saw first",
+  // in which case the ODR hash can differ across translation units for
+  // non-dependent, canonical template arguments that are spelled differently
+  // but have the same meaning. However, such examples are hard to construct.
+  ODRHash Hasher;
+  Hasher.AddInteger(TemplateArgs.size());
+  for (const TemplateArgument &TemplateArg : TemplateArgs)
+    Hasher.AddTemplateArgument(TemplateArg);
+  return Hasher.CalculateHash();
+}
+
+void ASTReader::LoadExternalSpecializations(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  assert(D);
+
+  auto It = SpecializationsLookups.find(D);
+  if (It == SpecializationsLookups.end())
+    return;
+
+  auto HashValue = GetTemplateArgsStableHash(TemplateArgs);
+
+  Deserializing LookupResults(this);
+  for (DeclID ID : It->second.Table.find(HashValue))
+    GetDecl(ID);
+}
+
 void ASTReader::completeVisibleDeclsMap(const DeclContext *DC) {
   if (!DC->hasExternalVisibleStorage())
     return;
@@ -7987,6 +8093,13 @@ ASTReader::getLoadedLookupTables(DeclContext *Primary) const {
   return I == Lookups.end() ? nullptr : &I->second;
 }
 
+serialization::reader::SpecializationsLookupTable *
+ASTReader::getLoadedSpecializationsLookupTables(const Decl *D) {
+  assert(D->isCanonicalDecl());
+  auto I = SpecializationsLookups.find(D);
+  return I == SpecializationsLookups.end() ? nullptr : &I->second;
+}
+
 /// Under non-PCH compilation the consumer receives the objc methods
 /// before receiving the implementation, and codegen depends on this.
 /// We simulate this by deserializing and passing to consumer the methods of the
diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp
index 547eb77930b4ee..e3ffcee6988b3f 100644
--- a/clang/lib/Serialization/ASTReaderDecl.cpp
+++ b/clang/lib/Serialization/ASTReaderDecl.cpp
@@ -274,9 +274,10 @@ namespace clang {
       // FIXME: We should avoid this pattern of getting the ASTContext.
       ASTContext &C = D->getASTContext();
 
-      auto *&LazySpecializations = D->getCommonPtr()->LazySpecializations;
+      auto *&ExternalSpecializations =
+          D->getCommonPtr()->ExternalSpecializations;
 
-      if (auto &Old = LazySpecializations) {
+      if (auto &Old = ExternalSpecializations) {
         IDs.insert(IDs.end(), Old + 1, Old + 1 + Old[0]);
         llvm::sort(IDs);
         IDs.erase(std::unique(IDs.begin(), IDs.end()), IDs.end());
@@ -286,7 +287,7 @@ namespace clang {
       *Result = IDs.size();
       std::copy(IDs.begin(), IDs.end(), Result + 1);
 
-      LazySpecializations = Result;
+      ExternalSpecializations = Result;
     }
 
     template <typename DeclT>
@@ -426,6 +427,9 @@ namespace clang {
 
     std::pair<uint64_t, uint64_t> VisitDeclContext(DeclContext *DC);
 
+    void ReadSpecializations(ModuleFile &M, Decl *D,
+                             llvm::BitstreamCursor &Cursor);
+
     template<typename T>
     RedeclarableResult VisitRedeclarable(Redeclarable<T> *D);
 
@@ -2431,10 +2435,14 @@ void ASTDeclReader::VisitClassTemplateDecl(ClassTemplateDecl *D) {
   mergeRedeclarableTemplate(D, Redecl);
 
   if (ThisDeclID == Redecl.getFirstID()) {
-    // This ClassTemplateDecl owns a CommonPtr; read it to keep track of all of
-    // the specializations.
+    // This ClassTemplateDecl owns a CommonPtr; read it to keep track of all
+    // of the specializations.
     SmallVector<serialization::DeclID, 32> SpecIDs;
     readDeclIDList(SpecIDs);
+
+    if (Record.readInt())
+      ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
+
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
   }
 
@@ -2463,6 +2471,10 @@ void ASTDeclReader::VisitVarTemplateDecl(VarTemplateDecl *D) {
     // the specializations.
     SmallVector<serialization::DeclID, 32> SpecIDs;
     readDeclIDList(SpecIDs);
+
+    if (Record.readInt())
+      ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
+
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
   }
 }
@@ -2566,6 +2578,9 @@ void ASTDeclReader::VisitFunctionTemplateDecl(FunctionTemplateDecl *D) {
     SmallVector<serialization::DeclID, 32> SpecIDs;
     readDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+
+    if (Record.readInt())
+      ReadSpecializations(*Loc.F, D, Loc.F->DeclsCursor);
   }
 }
 
@@ -2755,6 +2770,14 @@ ASTDeclReader::VisitDeclContext(DeclContext *DC) {
   return std::make_pair(LexicalOffset, VisibleOffset);
 }
 
+void ASTDeclReader::ReadSpecializations(ModuleFile &M, Decl *D,
+                                        llvm::BitstreamCursor &DeclsCursor) {
+  // Failure to read the table is treated as a broken invariant.
+  [[maybe_unused]] const bool Failed = Reader.ReadSpecializations(
+      M, DeclsCursor, ReadLocalOffset(), D);
+  assert(!Failed);
+}
+
 template <typename T>
 ASTDeclReader::RedeclarableResult
 ASTDeclReader::VisitRedeclarable(Redeclarable<T> *D) {
@@ -3800,6 +3823,7 @@ Decl *ASTReader::ReadDeclRecord(DeclID ID) {
   switch ((DeclCode)MaybeDeclCode.get()) {
   case DECL_CONTEXT_LEXICAL:
   case DECL_CONTEXT_VISIBLE:
+  case DECL_SPECIALIZATIONS:
     llvm_unreachable("Record cannot be de-serialized with readDeclRecord");
   case DECL_TYPEDEF:
     D = TypedefDecl::CreateDeserialized(Context, ID);
@@ -4112,6 +4136,7 @@ Decl *ASTReader::ReadDeclRecord(DeclID ID) {
         ReadVisibleDeclContextStorage(*Loc.F, DeclsCursor, Offsets.second, ID))
       return nullptr;
   }
+
   assert(Record.getIdx() == Record.size());
 
   // Load any relevant update records.
diff --git a/clang/lib/Serialization/ASTReaderInternals.h b/clang/lib/Serialization/ASTReaderInternals.h
index 25a46ddabcb707..9ac3c139815b08 100644
--- a/clang/lib/Serialization/ASTReaderInternals.h
+++ b/clang/lib/Serialization/ASTReaderInternals.h
@@ -119,6 +119,86 @@ struct DeclContextLookupTable {
   MultiOnDiskHashTable<ASTDeclContextNameLookupTrait> Table;
 };
 
+/// Class that performs lookup to specialized decls.
+class SpecializationsLookupTrait {
+  ASTReader &Reader;
+  ModuleFile &F;
+
+public:
+  // Maximum number of lookup tables we allow before condensing the tables.
+  static const int MaxTables = 4;
+
+  /// The lookup result is a list of global declaration IDs.
+  using data_type = SmallVector<DeclID, 4>;
+
+  struct data_type_builder {
+    data_type &Data;              // IDs collected so far, in insertion order.
+    llvm::DenseSet<DeclID> Found; // Set used to detect duplicate IDs.
+
+    data_type_builder(data_type &D) : Data(D) {}
+
+    void insert(DeclID ID) {
+      // Found is still unused: scan Data linearly while it holds few IDs.
+      if (Found.empty() && !Data.empty()) {
+        if (Data.size() <= 4) {
+          for (DeclID Seen : Data)
+            if (Seen == ID)
+              return;
+          Data.push_back(ID);
+          return;
+        }
+
+        // Switch to tracking found IDs in the set.
+        Found.insert(Data.begin(), Data.end());
+      }
+
+      if (Found.insert(ID).second)
+        Data.push_back(ID);
+    }
+  };
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+  using file_type = ModuleFile *;
+
+  using external_key_type = unsigned;
+  using internal_key_type = unsigned;
+
+  explicit SpecializationsLookupTrait(ASTReader &Reader, ModuleFile &F)
+      : Reader(Reader), F(F) {}
+
+  static bool EqualKey(const internal_key_type &a, const internal_key_type &b) {
+    return a == b;
+  }
+
+  static hash_value_type ComputeHash(const internal_key_type &Key) {
+    return Key;
+  }
+
+  static internal_key_type GetInternalKey(const external_key_type &Name) {
+    return Name;
+  }
+
+  static std::pair<unsigned, unsigned>
+  ReadKeyDataLength(const unsigned char *&d);
+
+  internal_key_type ReadKey(const unsigned char *d, unsigned);
+
+  void ReadDataInto(internal_key_type, const unsigned char *d, unsigned DataLen,
+                    data_type_builder &Val);
+
+  static void MergeDataInto(const data_type &From, data_type_builder &To) {
+    To.Data.reserve(To.Data.size() + From.size());
+    for (DeclID ID : From)
+      To.insert(ID);
+  }
+
+  file_type ReadFileRef(const unsigned char *&d);
+};
+
+struct SpecializationsLookupTable {
+  MultiOnDiskHashTable<SpecializationsLookupTrait> Table;
+};
+
 /// Base class for the trait describing the on-disk hash table for the
 /// identifiers in an AST file.
 ///
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 78939bfd533ffa..ed58e71b01c09a 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -29,6 +29,7 @@
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/LambdaCapture.h"
 #include "clang/AST/NestedNameSpecifier.h"
+#include "clang/AST/ODRHash.h"
 #include "clang/AST/OpenMPClause.h"
 #include "clang/AST/RawCommentList.h"
 #include "clang/AST/TemplateName.h"
@@ -3924,6 +3925,147 @@ class ASTDeclContextNameLookupTrait {
 
 } // namespace
 
+namespace {
+class SpecializationsLookupTrait {
+  ASTWriter &Writer;
+  llvm::SmallVector<DeclID, 64> DeclIDs;
+
+public:
+  using key_type = unsigned;
+  using key_type_ref = key_type;
+
+  /// A start and end index into DeclIDs, representing a sequence of decls.
+  using data_type = std::pair<unsigned, unsigned>;
+  using data_type_ref = const data_type &;
+
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+
+  explicit SpecializationsLookupTrait(ASTWriter &Writer) : Writer(Writer) {}
+
+  template <typename Col> data_type getData(Col &&C) {
+    unsigned Start = DeclIDs.size();
+    for (auto *D : C)
+      DeclIDs.push_back(Writer.GetDeclRef(getDeclForLocalLookup(
+          Writer.getLangOpts(), const_cast<NamedDecl *>(D))));
+    return std::make_pair(Start, DeclIDs.size());
+  }
+
+  data_type
+  ImportData(const reader::SpecializationsLookupTrait::data_type &FromReader) {
+    unsigned Start = DeclIDs.size();
+    for (auto ID : FromReader)
+      DeclIDs.push_back(ID);
+    return std::make_pair(Start, DeclIDs.size());
+  }
+
+  static bool EqualKey(key_type_ref a, key_type_ref b) { return a == b; }
+
+  hash_value_type ComputeHash(key_type Name) { return Name; }
+
+  void EmitFileRef(raw_ostream &Out, ModuleFile *F) const {
+    assert(Writer.hasChain() &&
+           "have reference to loaded module file but no chain?");
+
+    using namespace llvm::support;
+    endian::write<uint32_t>(Out, Writer.getChain()->getModuleFileID(F),
+                            llvm::endianness::little);
+  }
+
+  std::pair<unsigned, unsigned> EmitKeyDataLength(raw_ostream &Out,
+                                                  key_type HashValue,
+                                                  data_type_ref Lookup) {
+    // 4 bytes for each slot.
+    unsigned KeyLen = 4;
+    unsigned DataLen = 4 * (Lookup.second - Lookup.first);
+
+    return emitULEBKeyDataLength(KeyLen, DataLen, Out);
+  }
+
+  void EmitKey(raw_ostream &Out, key_type HashValue, unsigned) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    LE.write<uint32_t>(HashValue);
+  }
+
+  void EmitData(raw_ostream &Out, key_type_ref, data_type Lookup,
+                unsigned DataLen) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    uint64_t Start = Out.tell();
+    (void)Start;
+    for (unsigned I = Lookup.first, N = Lookup.second; I != N; ++I)
+      LE.write<uint32_t>(DeclIDs[I]);
+    assert(Out.tell() - Start == DataLen && "Data length is wrong");
+  }
+};
+
+unsigned CalculateODRHashForSpecs(const Decl *Spec) {
+  assert(!isa<ClassTemplatePartialSpecializationDecl>(Spec) &&
+         !isa<VarTemplatePartialSpecializationDecl>(Spec) &&
+         "We shouldn't see partial specializations here.");
+  // Every kind below is hashed by its template arguments only.
+  if (auto *FD = dyn_cast<FunctionDecl>(Spec)) {
+    auto *FDInfo = FD->getTemplateSpecializationInfo();
+    assert(FDInfo);
+    return GetTemplateArgsStableHash(FDInfo->TemplateArguments->asArray());
+  }
+
+  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(Spec))
+    return GetTemplateArgsStableHash(CTSD->getTemplateArgs().asArray());
+
+  if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(Spec))
+    return GetTemplateArgsStableHash(VTSD->getTemplateArgs().asArray());
+
+  llvm_unreachable("Unknown specialization kind?");
+}
+} // namespace
+
+/// Write the on-disk lookup table for \p Specializations of \p D as a
+/// DECL_SPECIALIZATIONS record and return the bit offset where it starts.
+uint64_t ASTWriter::WriteSpecializationsLookupTable(
+    const NamedDecl *D,
+    llvm::SmallVectorImpl<const NamedDecl *> &Specializations) {
+  assert(D->isFirstDecl());
+
+  // Create the on-disk hash table representation.
+  MultiOnDiskHashTableGenerator<reader::SpecializationsLookupTrait,
+                                SpecializationsLookupTrait>
+      Generator;
+  SpecializationsLookupTrait Trait(*this);
+
+  // Group the specializations by the hash of their template arguments;
+  // specializations that share a hash value end up in the same table entry.
+  llvm::DenseMap<unsigned, llvm::SmallVector<const NamedDecl *, 4>>
+      SpecializationMaps;
+  for (auto *Specialization : Specializations)
+    SpecializationMaps[CalculateODRHashForSpecs(Specialization)].push_back(
+        Specialization);
+
+  // Bind each entry by reference: iterating by value would copy the key and
+  // its SmallVector of declarations on every step.
+  for (auto &Entry : SpecializationMaps)
+    Generator.insert(Entry.first, Trait.getData(Entry.second), Trait);
+
+  // Remember where the record starts; this is the value handed back.
+  uint64_t Offset = Stream.GetCurrentBitNo();
+
+  // When the chained reader already has tables loaded for D, hand them to
+  // emit() together with the entries generated above.
+  auto *Lookups =
+      Chain ? Chain->getLoadedSpecializationsLookupTables(D) : nullptr;
+  llvm::SmallString<4096> LookupTable;
+  Generator.emit(LookupTable, Trait, Lookups ? &Lookups->Table : nullptr);
+
+  // Store the serialized table as the blob of a DECL_SPECIALIZATIONS record.
+  RecordData::value_type Record[] = {DECL_SPECIALIZATIONS};
+  Stream.EmitRecordWithBlob(DeclSpecializationsAbbrev, Record, LookupTable);
+
+  return Offset;
+}
+
 bool ASTWriter::isLookupResultExternal(StoredDeclsList &Result,
                                        DeclContext *DC) {
   return Result.hasExternalDecls() &&
@@ -5074,7 +5216,7 @@ ASTFileSignature ASTWriter::WriteASTCore(Sema &SemaRef, StringRef isysroot,
 
   // Keep writing types, declarations, and declaration update records
   // until we've emitted all of them.
-  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/5);
+  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/ 6);
   DeclTypesBlockStartOffset = Stream.GetCurrentBitNo();
   WriteTypeAbbrevs();
   WriteDeclAbbrevs();
diff --git a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp
index 9e3299f0491848..db3800025f8d27 100644
--- a/clang/lib/Serialization/ASTWriterDecl.cpp
+++ b/clang/lib/Serialization/ASTWriterDecl.cpp
@@ -207,22 +207,26 @@ namespace clang {
       return std::nullopt;
     }
 
-    template<typename DeclTy>
-    void AddTemplateSpecializations(DeclTy *D) {
+    template <typename DeclTy>
+    void AddTemplateSpecializations(
+        DeclTy *D, llvm::SmallVectorImpl<const NamedDecl *> &OptionalSpecs) {
       auto *Common = D->getCommonPtr();
 
       // If we have any lazy specializations, and the external AST source is
       // our chained AST reader, we can just write out the DeclIDs. Otherwise,
       // we need to resolve them to actual declarations.
       if (Writer.Chain != Writer.Context->getExternalSource() &&
-          Common->LazySpecializations) {
-        D->LoadLazySpecializations();
-        assert(!Common->LazySpecializations);
+          Common->ExternalSpecializations) {
+        D->loadExternalSpecializations();
+        assert(!Common->ExternalSpecializations);
       }
 
-      ArrayRef<DeclID> LazySpecializations;
-      if (auto *LS = Common->LazySpecializations)
-        LazySpecializations = llvm::ArrayRef(LS + 1, LS[0]);
+      for (auto &Entry : Common->Specializations)
+        OptionalSpecs.push_back(getSpecializationDecl(Entry));
+
+      ArrayRef<DeclID> ExternalSpecializations;
+      if (auto *LS = Common->ExternalSpecializations)
+        ExternalSpecializations = llvm::ArrayRef(LS + 1, LS[0]);
 
       // Add a slot to the record for the number of specializations.
       unsigned I = Record.size();
@@ -230,17 +234,20 @@ namespace clang {
 
       // AddFirstDeclFromEachModule might trigger deserialization, invalidating
       // *Specializations iterators.
-      llvm::SmallVector<const Decl*, 16> Specs;
-      for (auto &Entry : Common->Specializations)
-        Specs.push_back(getSpecializationDecl(Entry));
+      //
+      // We need to load all the partial specializations at once if the template
+      // is required, since we can't know whether a partial specialization will
+      // be needed before resolving a request to instantiate the template.
+      llvm::SmallVector<const Decl *, 16> PartialSpecs;
       for (auto &Entry : getPartialSpecializations(Common))
-        Specs.push_back(getSpecializationDecl(Entry));
+        PartialSpecs.push_back(getSpecializationDecl(Entry));
 
-      for (auto *D : Specs) {
+      for (auto *D : PartialSpecs) {
         assert(D->isCanonicalDecl() && "non-canonical decl in set");
         AddFirstDeclFromEachModule(D, /*IncludeLocal*/true);
       }
-      Record.append(LazySpecializations.begin(), LazySpecializations.end());
+      Record.append(ExternalSpecializations.begin(),
+                    ExternalSpecializations.end());
 
       // Update the size entry we added earlier.
       Record[I] = Record.size() - I - 1;
@@ -1670,8 +1677,16 @@ void ASTDeclWriter::VisitRedeclarableTemplateDecl(RedeclarableTemplateDecl *D) {
 void ASTDeclWriter::VisitClassTemplateDecl(ClassTemplateDecl *D) {
   VisitRedeclarableTemplateDecl(D);
 
-  if (D->isFirstDecl())
-    AddTemplateSpecializations(D);
+  if (D->isFirstDecl()) {
+    llvm::SmallVector<const NamedDecl *, 16> OptionalSpecs;
+    AddTemplateSpecializations(D, OptionalSpecs);
+    if (!OptionalSpecs.empty()) {
+      Record.push_back(1);
+      Record.AddOffset(
+          Writer.WriteSpecializationsLookupTable(D, OptionalSpecs));
+    } else
+      Record.push_back(0);
+  }
   Code = serialization::DECL_CLASS_TEMPLATE;
 }
 
@@ -1730,8 +1745,17 @@ void ASTDeclWriter::VisitClassTemplatePartialSpecializationDecl(
 void ASTDeclWriter::VisitVarTemplateDecl(VarTemplateDecl *D) {
   VisitRedeclarableTemplateDecl(D);
 
-  if (D->isFirstDecl())
-    AddTemplateSpecializations(D);
+  if (D->isFirstDecl()) {
+    llvm::SmallVector<const NamedDecl *, 16> OptionalSpecs;
+    AddTemplateSpecializations(D, OptionalSpecs);
+    if (!OptionalSpecs.empty()) {
+      Record.push_back(1);
+      Record.AddOffset(
+          Writer.WriteSpecializationsLookupTable(D, OptionalSpecs));
+    } else
+      Record.push_back(0);
+  }
+
   Code = serialization::DECL_VAR_TEMPLATE;
 }
 
@@ -1791,8 +1815,17 @@ void ASTDeclWriter::VisitVarTemplatePartialSpecializationDecl(
 void ASTDeclWriter::VisitFunctionTemplateDecl(FunctionTemplateDecl *D) {
   VisitRedeclarableTemplateDecl(D);
 
-  if (D->isFirstDecl())
-    AddTemplateSpecializations(D);
+  if (D->isFirstDecl()) {
+    llvm::SmallVector<const NamedDecl *, 16> OptionalSpecs;
+    AddTemplateSpecializations(D, OptionalSpecs);
+    if (!OptionalSpecs.empty()) {
+      Record.push_back(1);
+      Record.AddOffset(
+          Writer.WriteSpecializationsLookupTable(D, OptionalSpecs));
+    } else
+      Record.push_back(0);
+  }
+
   Code = serialization::DECL_FUNCTION_TEMPLATE;
 }
 
@@ -2657,6 +2690,11 @@ void ASTWriter::WriteDeclAbbrevs() {
   Abv->Add(BitCodeAbbrevOp(serialization::DECL_CONTEXT_VISIBLE));
   Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
   DeclContextVisibleLookupAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+  Abv = std::make_shared<BitCodeAbbrev>();
+  Abv->Add(BitCodeAbbrevOp(serialization::DECL_SPECIALIZATIONS));
+  Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
+  DeclSpecializationsAbbrev = Stream.EmitAbbrev(std::move(Abv));
 }
 
 /// isRequiredDecl - Check if this is a "required" Decl, which must be seen by
diff --git a/clang/test/Modules/odr_hash.cpp b/clang/test/Modules/odr_hash.cpp
index 220ef767df849a..4483cf21e5c623 100644
--- a/clang/test/Modules/odr_hash.cpp
+++ b/clang/test/Modules/odr_hash.cpp
@@ -2897,7 +2897,7 @@ struct S5 {
 };
 #else
 S5 s5;
-// expected-error at second.h:* {{'PointersAndReferences::S5::x' from module 'SecondModule' is not present in definition of 'PointersAndReferences::S5' in module 'FirstModule'}}
+// expected-error at first.h:* {{'PointersAndReferences::S5::x' from module 'FirstModule' is not present in definition of 'PointersAndReferences::S5' in module 'SecondModule'}}
 // expected-note at first.h:* {{declaration of 'x' does not match}}
 #endif
 
@@ -3834,7 +3834,7 @@ struct Valid {
 #else
 Invalid::L2<1>::L3<1> invalid;
 // expected-error at second.h:* {{'Types::InjectedClassName::Invalid::L2::L3::x' from module 'SecondModule' is not present in definition of 'L3<>' in module 'FirstModule'}}
-// expected-note at first.h:* {{declaration of 'x' does not match}}
+// expected-note at second.h:* {{declaration of 'x' does not match}}
 Valid::L2<1>::L3<1> valid;
 #endif
 }  // namespace InjectedClassName
diff --git a/clang/unittests/Serialization/CMakeLists.txt b/clang/unittests/Serialization/CMakeLists.txt
index 10d7de970c643d..6276bbb6d0cb4b 100644
--- a/clang/unittests/Serialization/CMakeLists.txt
+++ b/clang/unittests/Serialization/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_LINK_COMPONENTS
 add_clang_unittest(SerializationTests
   ForceCheckFileInputTest.cpp
   InMemoryModuleCacheTest.cpp
+  LoadSpecLazily.cpp
   ModuleCacheTest.cpp
   NoCommentsTest.cpp
   SourceLocationEncodingTest.cpp
diff --git a/clang/unittests/Serialization/LoadSpecLazily.cpp b/clang/unittests/Serialization/LoadSpecLazily.cpp
new file mode 100644
index 00000000000000..03f3ff3f786555
--- /dev/null
+++ b/clang/unittests/Serialization/LoadSpecLazily.cpp
@@ -0,0 +1,159 @@
+//== unittests/Serialization/LoadSpecLazily.cpp ----------------------========//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Frontend/FrontendActions.h"
+#include "clang/Parse/ParseAST.h"
+#include "clang/Serialization/ASTDeserializationListener.h"
+#include "clang/Tooling/Tooling.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace clang;
+using namespace clang::tooling;
+
+namespace {
+
+class LoadSpecLazilyTest : public ::testing::Test {
+  void SetUp() override {
+    ASSERT_FALSE(
+        sys::fs::createUniqueDirectory("load-spec-lazily-test", TestDir));
+  }
+
+  void TearDown() override { sys::fs::remove_directories(TestDir); }
+
+public:
+  SmallString<256> TestDir;
+
+  void addFile(StringRef Path, StringRef Contents) {
+    ASSERT_FALSE(sys::path::is_absolute(Path));
+
+    SmallString<256> AbsPath(TestDir);
+    sys::path::append(AbsPath, Path);
+
+    ASSERT_FALSE(
+        sys::fs::create_directories(llvm::sys::path::parent_path(AbsPath)));
+
+    std::error_code EC;
+    llvm::raw_fd_ostream OS(AbsPath, EC);
+    ASSERT_FALSE(EC);
+    OS << Contents;
+  }
+
+  std::string GenerateModuleInterface(StringRef ModuleName,
+                                      StringRef Contents) {
+    std::string FileName = llvm::Twine(ModuleName + ".cppm").str();
+    addFile(FileName, Contents);
+
+    IntrusiveRefCntPtr<DiagnosticsEngine> Diags =
+        CompilerInstance::createDiagnostics(new DiagnosticOptions());
+    CreateInvocationOptions CIOpts;
+    CIOpts.Diags = Diags;
+    CIOpts.VFS = llvm::vfs::createPhysicalFileSystem();
+    // Name the BMI <ModuleName>.pcm so -fprebuilt-module-path can locate it.
+    std::string CacheBMIPath =
+        llvm::Twine(TestDir + "/" + ModuleName + ".pcm").str();
+    std::string PrebuiltModulePath =
+        "-fprebuilt-module-path=" + TestDir.str().str();
+    const char *Args[] = {"clang++",
+                          "-std=c++20",
+                          "--precompile",
+                          PrebuiltModulePath.c_str(),
+                          "-working-directory",
+                          TestDir.c_str(),
+                          "-I",
+                          TestDir.c_str(),
+                          FileName.c_str(),
+                          "-o",
+                          CacheBMIPath.c_str()};
+    std::shared_ptr<CompilerInvocation> Invocation =
+        createInvocation(Args, CIOpts);
+    EXPECT_TRUE(Invocation);
+    // Build the BMI; the interface unit must compile without errors.
+    CompilerInstance Instance;
+    Instance.setDiagnostics(Diags.get());
+    Instance.setInvocation(Invocation);
+    GenerateModuleInterfaceAction Action;
+    EXPECT_TRUE(Instance.ExecuteAction(Action));
+    EXPECT_FALSE(Diags->hasErrorOccurred());
+
+    return CacheBMIPath;
+  }
+};
+
+class DeclsReaderListener : public ASTDeserializationListener {
+public:
+  void DeclRead(serialization::DeclID ID, const Decl *D) override {
+    auto *ND = dyn_cast<NamedDecl>(D);
+    if (!ND || !ND->getIdentifier())
+      return;
+    // Past this point the name is a simple identifier, as getName() requires.
+    EXPECT_FALSE(ND->getName().contains(ForbiddenName));
+  }
+
+  DeclsReaderListener(StringRef ForbiddenName) : ForbiddenName(ForbiddenName) {}
+
+  StringRef ForbiddenName;
+};
+
+class LoadSpecLazilyConsumer : public ASTConsumer {
+  DeclsReaderListener Listener;
+
+public:
+  LoadSpecLazilyConsumer(StringRef ForbiddenName) : Listener(ForbiddenName) {}
+
+  ASTDeserializationListener *GetASTDeserializationListener() override {
+    return &Listener;
+  }
+};
+
+class CheckLoadSpecLazilyAction : public ASTFrontendAction {
+  StringRef ForbiddenName;
+
+public:
+  std::unique_ptr<ASTConsumer>
+  CreateASTConsumer(CompilerInstance &CI, StringRef /*Unused*/) override {
+    return std::make_unique<LoadSpecLazilyConsumer>(ForbiddenName);
+  }
+
+  CheckLoadSpecLazilyAction(StringRef ForbiddenName)
+      : ForbiddenName(ForbiddenName) {}
+};
+
+TEST_F(LoadSpecLazilyTest, BasicTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+
+export class ShouldNotBeLoaded {};
+
+export class Temp {
+   A<ShouldNotBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import M;
+A<int> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(runToolOnCodeWithArgs(
+      std::make_unique<CheckLoadSpecLazilyAction>("ShouldNotBeLoaded"),
+      test_file_contents,
+      {
+          "-std=c++20",
+          DepArg.c_str(),
+          "-I",
+          TestDir.c_str(),
+      },
+      "test.cpp"));
+}
+
+} // namespace



More information about the cfe-commits mailing list