[clang] [Serialization] Load Specializations Lazily (1/2) (PR #76774)

Chuanqi Xu via cfe-commits cfe-commits at lists.llvm.org
Tue Jan 2 19:26:58 PST 2024


https://github.com/ChuanqiXu9 created https://github.com/llvm/llvm-project/pull/76774

The idea comes from @vgvassilev, who had a patch for it on Phabricator. Unfortunately, Phabricator is now closed and I no longer remember the Dxxx number of that patch. I do remember that the last comment from @vgvassilev was that we should use MultiOnDiskHashTable for it, so I followed that advice and rewrote the whole thing from scratch in the new year.

### Background

Currently, all the specializations of a template (including instantiations, explicit specializations, and partial specializations) are loaded at once whenever we want to instantiate another instance of the template, look up an existing instantiation of the template, or simply complete the redecl chain.

In practice, this means we load every specialization of a template as soon as the template declaration itself is loaded. This is costly because loading a specialization also requires loading all of its template arguments, so we end up deserializing many unnecessary declarations.

For example,

```
// M.cppm
export module M;
export template <class T>
class A {};

export class ShouldNotBeLoaded {};

export class Temp {
   A<ShouldNotBeLoaded> AS;
};

// use.cpp
import M;
A<int> a;
```

We have a specialization `A<ShouldNotBeLoaded>` in `M.cppm`, and we instantiate the template `A` (as `A<int>`) in `use.cpp`. Surprisingly, compiling `use.cpp` then deserializes `ShouldNotBeLoaded`, even though `A<int>` has nothing to do with it. This patch tries to avoid that.

Given how heavily templates are used in C++, this is a significant performance pain point.

### What this patch does

This patch adds a MultiOnDiskHashTable for specializations to the ASTReader, so that we only deserialize the specializations whose template arguments match the requested ones. We do this by using the ODRHash of the template arguments as the key of the hash table.
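
A minimal sketch of the lookup scheme, under stated assumptions: declarations are modeled as integer IDs, template arguments as strings, and `hashArgs` is a hypothetical stand-in for ODRHash (in the patch, `ASTReader::LoadExternalSpecs` feeds the argument count and then each argument into `ODRHash`). `LazySpecTable` is an illustrative type, not a Clang class. Because distinct argument lists can share a hash, a bucket may hold extra candidates; in the patch, the exact match is still done afterwards by `findSpecializationImpl`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical model: declarations are integer IDs and template arguments
// are spelled as strings. None of these names are Clang APIs.
using DeclID = std::uint32_t;
using TemplateArgs = std::vector<std::string>;

// Stand-in for ODRHash: mix in the argument count first, then each argument,
// mirroring the order used by ASTReader::LoadExternalSpecs in the patch.
std::uint64_t hashArgs(const TemplateArgs &Args) {
  std::uint64_t H = std::hash<std::size_t>{}(Args.size());
  for (const std::string &Arg : Args)
    H = (H * 1099511628211ULL) ^ std::hash<std::string>{}(Arg);
  return H;
}

class LazySpecTable {
public:
  // Writer side: file a specialization under the hash of its arguments.
  // Several IDs may share one key (a "multi" table).
  void add(const TemplateArgs &Args, DeclID ID) {
    Table[hashArgs(Args)].push_back(ID);
  }

  // Reader side: "deserialize" only the bucket for the requested arguments.
  // A hash collision can pull in extra candidates, so the caller must still
  // compare the real arguments afterwards.
  void load(const TemplateArgs &Args) {
    auto It = Table.find(hashArgs(Args));
    if (It == Table.end())
      return;
    for (DeclID ID : It->second)
      Loaded.insert(ID);
  }

  bool isLoaded(DeclID ID) const { return Loaded.count(ID) != 0; }
  std::size_t numLoaded() const { return Loaded.size(); }

private:
  std::unordered_map<std::uint64_t, std::vector<DeclID>> Table;
  std::unordered_set<DeclID> Loaded;
};
```

Applied to the `M.cppm`/`use.cpp` example: after `add({"ShouldNotBeLoaded"}, 1)` and a hypothetical stored `add({"int"}, 2)`, a call to `load({"int"})` marks only ID 2 as loaded, and the `A<ShouldNotBeLoaded>` entry is never touched.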

Partial specializations are not added to the MultiOnDiskHashTable, because we can't know whether a partial specialization is needed until we have selected the template declaration for an instantiation request. There may be room for further optimization here, but let's leave that for the future.
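
A small standalone illustration of our reading of why partial specializations stay out of the table (the template `A` and its `Kind` member are hypothetical, not from the patch). A request for `A<int *>` carries the argument list `<int *>`, but the partial specialization that ends up being used is written with the pattern `<T *>`, so a hash of the requested arguments would not point at it. Picking a partial specialization requires matching the request against every candidate pattern.

```cpp
// Illustrative only: a lookup keyed on the requested argument list cannot
// locate a partial specialization, because it is written with a pattern.

// Primary template.
template <class T> struct A { static constexpr int Kind = 0; };

// Partial specialization, written with the pattern <T *>.
template <class T> struct A<T *> { static constexpr int Kind = 1; };

// The request A<int *> carries the arguments <int *>, while the matching
// partial specialization is written as <T *>. Finding it means trying
// deduction against each partial specialization, not a hash lookup.
static_assert(A<int *>::Kind == 1, "A<int *> uses the partial specialization");
static_assert(A<int>::Kind == 0, "A<int> uses the primary template");
```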

To review this patch, I think `ASTDeclReader::AddLazySpecializations` (in `ASTReaderDecl.cpp`) is a good entry point.

### What this patch does not do

This patch doesn't solve the problem completely, because we still emit `update` specializations when new specializations are added in a different module:
https://github.com/llvm/llvm-project/blob/8ae73fea3a2cbb072bf3e577dc49deb25b56e760/clang/lib/Serialization/ASTWriterDecl.cpp#L251-L269

That is, we can't handle this case yet:

```
// M.cppm
export module M;
export template <class T>
class A {};

// N.cppm
export module N;
export import M;
export class ShouldNotBeLoaded {};

export class Temp {
   A<ShouldNotBeLoaded> AS;
};

// use.cpp
import N;
A<int> a;
```

Here, `ShouldNotBeLoaded` is still loaded when compiling `use.cpp`.

However, the current patch is already relatively large, so I'd like to leave this to a follow-up patch. I believe the current patch is self-contained.

From f4191661961428a0f534f527774ac3d5159c5103 Mon Sep 17 00:00:00 2001
From: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
Date: Tue, 2 Jan 2024 10:43:03 +0800
Subject: [PATCH] Load Specializations Lazily

---
 clang/include/clang/AST/DeclTemplate.h        |  51 ++++--
 clang/include/clang/AST/ExternalASTSource.h   |   5 +
 clang/include/clang/AST/ODRHash.h             |   3 +
 .../clang/Sema/MultiplexExternalSemaSource.h  |   6 +
 .../include/clang/Serialization/ASTBitCodes.h |   3 +
 clang/include/clang/Serialization/ASTReader.h |  19 +++
 clang/include/clang/Serialization/ASTWriter.h |   6 +
 clang/lib/AST/DeclTemplate.cpp                |  66 +++++---
 clang/lib/AST/ExternalASTSource.cpp           |   5 +
 clang/lib/AST/ODRHash.cpp                     |   2 +
 .../lib/Sema/MultiplexExternalSemaSource.cpp  |   6 +
 clang/lib/Serialization/ASTReader.cpp         | 109 +++++++++++-
 clang/lib/Serialization/ASTReaderDecl.cpp     |  33 +++-
 clang/lib/Serialization/ASTReaderInternals.h  |  80 +++++++++
 clang/lib/Serialization/ASTWriter.cpp         | 149 +++++++++++++++-
 clang/lib/Serialization/ASTWriterDecl.cpp     |  75 ++++++---
 clang/test/Modules/odr_hash.cpp               |   4 +-
 .../Modules/static-member-in-templates.cppm   |  52 ++++++
 clang/unittests/Serialization/CMakeLists.txt  |   1 +
 .../Serialization/LoadSpecLazily.cpp          | 159 ++++++++++++++++++
 20 files changed, 769 insertions(+), 65 deletions(-)
 create mode 100644 clang/test/Modules/static-member-in-templates.cppm
 create mode 100644 clang/unittests/Serialization/LoadSpecLazily.cpp

diff --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h
index 832ad2de6b08a8..ab380f55c038ee 100644
--- a/clang/include/clang/AST/DeclTemplate.h
+++ b/clang/include/clang/AST/DeclTemplate.h
@@ -30,6 +30,7 @@
 #include "llvm/ADT/FoldingSet.h"
 #include "llvm/ADT/PointerIntPair.h"
 #include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Casting.h"
@@ -525,8 +526,11 @@ class FunctionTemplateSpecializationInfo final
     return Function.getInt();
   }
 
+  void loadExternalRedecls();
+
 public:
   friend TrailingObjects;
+  friend class ASTReader;
 
   static FunctionTemplateSpecializationInfo *
   Create(ASTContext &C, FunctionDecl *FD, FunctionTemplateDecl *Template,
@@ -789,13 +793,15 @@ class RedeclarableTemplateDecl : public TemplateDecl,
     return SpecIterator<EntryType>(isEnd ? Specs.end() : Specs.begin());
   }
 
-  void loadLazySpecializationsImpl() const;
+  void loadExternalSpecializations() const;
 
   template <class EntryType, typename ...ProfileArguments>
   typename SpecEntryTraits<EntryType>::DeclType*
   findSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
                          void *&InsertPos, ProfileArguments &&...ProfileArgs);
 
+  void loadLazySpecializationsWithArgs(ArrayRef<TemplateArgument> TemplateArgs);
+
   template <class Derived, class EntryType>
   void addSpecializationImpl(llvm::FoldingSetVector<EntryType> &Specs,
                              EntryType *Entry, void *InsertPos);
@@ -814,9 +820,13 @@ class RedeclarableTemplateDecl : public TemplateDecl,
     /// If non-null, points to an array of specializations (including
     /// partial specializations) known only by their external declaration IDs.
     ///
+    /// These specializations need to be loaded at once in
+    /// loadExternalSpecializations to complete the redecl chain or to prepare
+    /// for template resolution.
+    ///
     /// The first value in the array is the number of specializations/partial
     /// specializations that follow.
-    uint32_t *LazySpecializations = nullptr;
+    uint32_t *ExternalSpecializations = nullptr;
 
     /// The set of "injected" template arguments used within this
     /// template.
@@ -850,6 +860,8 @@ class RedeclarableTemplateDecl : public TemplateDecl,
   friend class ASTDeclWriter;
   friend class ASTReader;
   template <class decl_type> friend class RedeclarableTemplate;
+  friend class ClassTemplateSpecializationDecl;
+  friend class VarTemplateSpecializationDecl;
 
   /// Retrieves the canonical declaration of this template.
   RedeclarableTemplateDecl *getCanonicalDecl() override {
@@ -977,6 +989,12 @@ SpecEntryTraits<FunctionTemplateSpecializationInfo> {
 class FunctionTemplateDecl : public RedeclarableTemplateDecl {
 protected:
   friend class FunctionDecl;
+  friend class FunctionTemplateSpecializationInfo;
+
+  template <typename DeclTy>
+  friend void GetSpecializationsImpl(const DeclTy *,
+                                     llvm::SmallPtrSetImpl<const NamedDecl *> &,
+                                     ASTReader *Reader);
 
   /// Data that is common to all of the declarations of a given
   /// function template.
@@ -1012,13 +1030,13 @@ class FunctionTemplateDecl : public RedeclarableTemplateDecl {
   void addSpecialization(FunctionTemplateSpecializationInfo* Info,
                          void *InsertPos);
 
+  /// Load any lazily-loaded specializations from the external source.
+  void LoadLazySpecializations() const;
+
 public:
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
 
-  /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
-
   /// Get the underlying function declaration of the template.
   FunctionDecl *getTemplatedDecl() const {
     return static_cast<FunctionDecl *>(TemplatedDecl);
@@ -1839,6 +1857,8 @@ class ClassTemplateSpecializationDecl
   LLVM_PREFERRED_TYPE(TemplateSpecializationKind)
   unsigned SpecializationKind : 3;
 
+  void loadExternalRedecls();
+
 protected:
   ClassTemplateSpecializationDecl(ASTContext &Context, Kind DK, TagKind TK,
                                   DeclContext *DC, SourceLocation StartLoc,
@@ -1852,6 +1872,7 @@ class ClassTemplateSpecializationDecl
 public:
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
+  friend class ASTReader;
 
   static ClassTemplateSpecializationDecl *
   Create(ASTContext &Context, TagKind TK, DeclContext *DC,
@@ -2238,6 +2259,11 @@ class ClassTemplatePartialSpecializationDecl
 /// Declaration of a class template.
 class ClassTemplateDecl : public RedeclarableTemplateDecl {
 protected:
+  template <typename DeclTy>
+  friend void GetSpecializationsImpl(const DeclTy *,
+                                     llvm::SmallPtrSetImpl<const NamedDecl *> &,
+                                     ASTReader *Reader);
+
   /// Data that is common to all of the declarations of a given
   /// class template.
   struct Common : CommonBase {
@@ -2285,9 +2311,7 @@ class ClassTemplateDecl : public RedeclarableTemplateDecl {
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
   friend class TemplateDeclInstantiator;
-
-  /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
+  friend class ClassTemplateSpecializationDecl;
 
   /// Get the underlying class declarations of the template.
   CXXRecordDecl *getTemplatedDecl() const {
@@ -2651,6 +2675,8 @@ class VarTemplateSpecializationDecl : public VarDecl,
   LLVM_PREFERRED_TYPE(bool)
   unsigned IsCompleteDefinition : 1;
 
+  void loadExternalRedecls();
+
 protected:
   VarTemplateSpecializationDecl(Kind DK, ASTContext &Context, DeclContext *DC,
                                 SourceLocation StartLoc, SourceLocation IdLoc,
@@ -2664,6 +2690,7 @@ class VarTemplateSpecializationDecl : public VarDecl,
 public:
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
+  friend class ASTReader;
   friend class VarDecl;
 
   static VarTemplateSpecializationDecl *
@@ -3018,6 +3045,11 @@ class VarTemplatePartialSpecializationDecl
 /// Declaration of a variable template.
 class VarTemplateDecl : public RedeclarableTemplateDecl {
 protected:
+  template <typename DeclTy>
+  friend void GetSpecializationsImpl(const DeclTy *,
+                                     llvm::SmallPtrSetImpl<const NamedDecl *> &,
+                                     ASTReader *Reader);
+
   /// Data that is common to all of the declarations of a given
   /// variable template.
   struct Common : CommonBase {
@@ -3057,8 +3089,7 @@ class VarTemplateDecl : public RedeclarableTemplateDecl {
   friend class ASTDeclReader;
   friend class ASTDeclWriter;
 
-  /// Load any lazily-loaded specializations from the external source.
-  void LoadLazySpecializations() const;
+  friend class VarTemplatePartialSpecializationDecl;
 
   /// Get the underlying variable declarations of the template.
   VarDecl *getTemplatedDecl() const {
diff --git a/clang/include/clang/AST/ExternalASTSource.h b/clang/include/clang/AST/ExternalASTSource.h
index 8e573965b0a336..7f26afd53106ba 100644
--- a/clang/include/clang/AST/ExternalASTSource.h
+++ b/clang/include/clang/AST/ExternalASTSource.h
@@ -150,6 +150,11 @@ class ExternalASTSource : public RefCountedBase<ExternalASTSource> {
   virtual bool
   FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name);
 
+  /// Load all the external specializations for the Decl and the corresponding
+  /// template arguments.
+  virtual void LoadExternalSpecs(const Decl *D,
+                                 ArrayRef<TemplateArgument> TemplateArgs);
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   ///
diff --git a/clang/include/clang/AST/ODRHash.h b/clang/include/clang/AST/ODRHash.h
index cedf644520fc32..ddd1bb0f095e75 100644
--- a/clang/include/clang/AST/ODRHash.h
+++ b/clang/include/clang/AST/ODRHash.h
@@ -101,6 +101,9 @@ class ODRHash {
   // Save booleans until the end to lower the size of data to process.
   void AddBoolean(bool value);
 
+  // Add integers to ID.
+  void AddInteger(unsigned Value);
+
   static bool isSubDeclToBeProcessed(const Decl *D, const DeclContext *Parent);
 
 private:
diff --git a/clang/include/clang/Sema/MultiplexExternalSemaSource.h b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
index 2bf91cb5212c5e..886c3854adac6e 100644
--- a/clang/include/clang/Sema/MultiplexExternalSemaSource.h
+++ b/clang/include/clang/Sema/MultiplexExternalSemaSource.h
@@ -97,6 +97,12 @@ class MultiplexExternalSemaSource : public ExternalSemaSource {
   bool FindExternalVisibleDeclsByName(const DeclContext *DC,
                                       DeclarationName Name) override;
 
+  /// Load all the external specializations for the Decl and the corresponding
+  /// template args.
+  virtual void
+  LoadExternalSpecs(const Decl *D,
+                    ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Ensures that the table of all visible declarations inside this
   /// context is up to date.
   void completeVisibleDeclsMap(const DeclContext *DC) override;
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index fdd64f2abbe937..a1bf3659e91f3e 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1523,6 +1523,9 @@ enum DeclCode {
   /// An ImplicitConceptSpecializationDecl record.
   DECL_IMPLICIT_CONCEPT_SPECIALIZATION,
 
+  // A record of a decl's specializations.
+  DECL_SPECS,
+
   DECL_LAST = DECL_IMPLICIT_CONCEPT_SPECIALIZATION
 };
 
diff --git a/clang/include/clang/Serialization/ASTReader.h b/clang/include/clang/Serialization/ASTReader.h
index 21d791f5cd89a2..52ca6c76db8e37 100644
--- a/clang/include/clang/Serialization/ASTReader.h
+++ b/clang/include/clang/Serialization/ASTReader.h
@@ -340,6 +340,9 @@ class ASTIdentifierLookupTrait;
 /// The on-disk hash table(s) used for DeclContext name lookup.
 struct DeclContextLookupTable;
 
+/// The on-disk hash table(s) used for specialization decls.
+struct SpecializedDeclsLookupTable;
+
 } // namespace reader
 
 } // namespace serialization
@@ -599,6 +602,11 @@ class ASTReader
   llvm::DenseMap<const DeclContext *,
                  serialization::reader::DeclContextLookupTable> Lookups;
 
+  /// Map from decls to specialized decls.
+  llvm::DenseMap<const Decl *,
+                 serialization::reader::SpecializedDeclsLookupTable>
+      SpecLookups;
+
   // Updates for visible decls can occur for other contexts than just the
   // TU, and when we read those update records, the actual context may not
   // be available yet, so have this pending map using the ID as a key. It
@@ -640,6 +648,9 @@ class ASTReader
                                      llvm::BitstreamCursor &Cursor,
                                      uint64_t Offset, serialization::DeclID ID);
 
+  bool ReadDeclsSpecs(ModuleFile &M, llvm::BitstreamCursor &Cursor,
+                      uint64_t Offset, Decl *D);
+
   /// A vector containing identifiers that have already been
   /// loaded.
   ///
@@ -1343,6 +1354,11 @@ class ASTReader
   const serialization::reader::DeclContextLookupTable *
   getLoadedLookupTables(DeclContext *Primary) const;
 
+  /// Get the loaded specializations lookup tables for \p D,
+  /// if any.
+  serialization::reader::SpecializedDeclsLookupTable *
+  getLoadedSpecLookupTables(Decl *D);
+
 private:
   struct ImportedModule {
     ModuleFile *Mod;
@@ -1982,6 +1998,9 @@ class ASTReader
   bool FindExternalVisibleDeclsByName(const DeclContext *DC,
                                       DeclarationName Name) override;
 
+  void LoadExternalSpecs(const Decl *D,
+                         ArrayRef<TemplateArgument> TemplateArgs) override;
+
   /// Read all of the declarations lexically stored in a
   /// declaration context.
   ///
diff --git a/clang/include/clang/Serialization/ASTWriter.h b/clang/include/clang/Serialization/ASTWriter.h
index de69f99003d827..c98beaa1a24dc0 100644
--- a/clang/include/clang/Serialization/ASTWriter.h
+++ b/clang/include/clang/Serialization/ASTWriter.h
@@ -527,6 +527,10 @@ class ASTWriter : public ASTDeserializationListener,
   bool isLookupResultExternal(StoredDeclsList &Result, DeclContext *DC);
   bool isLookupResultEntirelyExternal(StoredDeclsList &Result, DeclContext *DC);
 
+  uint64_t
+  WriteSpecsLookupTable(NamedDecl *D,
+                        llvm::SmallVectorImpl<const NamedDecl *> &Specs);
+
   void GenerateNameLookupTable(const DeclContext *DC,
                                llvm::SmallVectorImpl<char> &LookupTable);
   uint64_t WriteDeclContextLexicalBlock(ASTContext &Context, DeclContext *DC);
@@ -564,6 +568,8 @@ class ASTWriter : public ASTDeserializationListener,
   unsigned DeclEnumAbbrev = 0;
   unsigned DeclObjCIvarAbbrev = 0;
   unsigned DeclCXXMethodAbbrev = 0;
+  unsigned DeclSpecsAbbrev = 0;
+
   unsigned DeclDependentNonTemplateCXXMethodAbbrev = 0;
   unsigned DeclTemplateCXXMethodAbbrev = 0;
   unsigned DeclMemberSpecializedCXXMethodAbbrev = 0;
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index 7d7556e670f951..43c9158fb40413 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -331,14 +331,14 @@ RedeclarableTemplateDecl::CommonBase *RedeclarableTemplateDecl::getCommonPtr() c
   return Common;
 }
 
-void RedeclarableTemplateDecl::loadLazySpecializationsImpl() const {
+void RedeclarableTemplateDecl::loadExternalSpecializations() const {
   // Grab the most recent declaration to ensure we've loaded any lazy
   // redeclarations of this template.
   CommonBase *CommonBasePtr = getMostRecentDecl()->getCommonPtr();
-  if (CommonBasePtr->LazySpecializations) {
+  if (CommonBasePtr->ExternalSpecializations) {
     ASTContext &Context = getASTContext();
-    uint32_t *Specs = CommonBasePtr->LazySpecializations;
-    CommonBasePtr->LazySpecializations = nullptr;
+    uint32_t *Specs = CommonBasePtr->ExternalSpecializations;
+    CommonBasePtr->ExternalSpecializations = nullptr;
     for (uint32_t I = 0, N = *Specs++; I != N; ++I)
       (void)Context.getExternalSource()->GetExternalDecl(Specs[I]);
   }
@@ -358,6 +358,15 @@ RedeclarableTemplateDecl::findSpecializationImpl(
   return Entry ? SETraits::getDecl(Entry)->getMostRecentDecl() : nullptr;
 }
 
+void RedeclarableTemplateDecl::loadLazySpecializationsWithArgs(
+    ArrayRef<TemplateArgument> TemplateArgs) {
+  auto *ExternalSource = getASTContext().getExternalSource();
+  if (!ExternalSource)
+    return;
+
+  ExternalSource->LoadExternalSpecs(this->getCanonicalDecl(), TemplateArgs);
+}
+
 template<class Derived, class EntryType>
 void RedeclarableTemplateDecl::addSpecializationImpl(
     llvm::FoldingSetVector<EntryType> &Specializations, EntryType *Entry,
@@ -430,24 +439,23 @@ FunctionTemplateDecl::newCommon(ASTContext &C) const {
   return CommonPtr;
 }
 
-void FunctionTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
-}
-
 llvm::FoldingSetVector<FunctionTemplateSpecializationInfo> &
 FunctionTemplateDecl::getSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->Specializations;
 }
 
 FunctionDecl *
 FunctionTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                          void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void FunctionTemplateDecl::addSpecialization(
       FunctionTemplateSpecializationInfo *Info, void *InsertPos) {
+  using SETraits = SpecEntryTraits<FunctionTemplateSpecializationInfo>;
+  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(Info));
   addSpecializationImpl<FunctionTemplateDecl>(getSpecializations(), Info,
                                               InsertPos);
 }
@@ -508,19 +516,15 @@ ClassTemplateDecl *ClassTemplateDecl::CreateDeserialized(ASTContext &C,
                                        DeclarationName(), nullptr, nullptr);
 }
 
-void ClassTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
-}
-
 llvm::FoldingSetVector<ClassTemplateSpecializationDecl> &
 ClassTemplateDecl::getSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->Specializations;
 }
 
 llvm::FoldingSetVector<ClassTemplatePartialSpecializationDecl> &
 ClassTemplateDecl::getPartialSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->PartialSpecializations;
 }
 
@@ -534,11 +538,14 @@ ClassTemplateDecl::newCommon(ASTContext &C) const {
 ClassTemplateSpecializationDecl *
 ClassTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                       void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void ClassTemplateDecl::AddSpecialization(ClassTemplateSpecializationDecl *D,
                                           void *InsertPos) {
+  using SETraits = SpecEntryTraits<ClassTemplateSpecializationDecl>;
+  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(D));
   addSpecializationImpl<ClassTemplateDecl>(getSpecializations(), D, InsertPos);
 }
 
@@ -546,6 +553,7 @@ ClassTemplatePartialSpecializationDecl *
 ClassTemplateDecl::findPartialSpecialization(
     ArrayRef<TemplateArgument> Args,
     TemplateParameterList *TPL, void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getPartialSpecializations(), InsertPos, Args,
                                 TPL);
 }
@@ -900,6 +908,11 @@ FunctionTemplateSpecializationInfo *FunctionTemplateSpecializationInfo::Create(
       FD, Template, TSK, TemplateArgs, ArgsAsWritten, POI, MSInfo);
 }
 
+void FunctionTemplateSpecializationInfo::loadExternalRedecls() {
+  getTemplate()->loadExternalSpecializations();
+  getTemplate()->loadLazySpecializationsWithArgs(TemplateArguments->asArray());
+}
+
 //===----------------------------------------------------------------------===//
 // ClassTemplateSpecializationDecl Implementation
 //===----------------------------------------------------------------------===//
@@ -1024,6 +1037,12 @@ ClassTemplateSpecializationDecl::getSourceRange() const {
   }
 }
 
+void ClassTemplateSpecializationDecl::loadExternalRedecls() {
+  getSpecializedTemplate()->loadExternalSpecializations();
+  getSpecializedTemplate()->loadLazySpecializationsWithArgs(
+      getTemplateArgs().asArray());
+}
+
 //===----------------------------------------------------------------------===//
 // ConceptDecl Implementation
 //===----------------------------------------------------------------------===//
@@ -1226,19 +1245,15 @@ VarTemplateDecl *VarTemplateDecl::CreateDeserialized(ASTContext &C,
                                      DeclarationName(), nullptr, nullptr);
 }
 
-void VarTemplateDecl::LoadLazySpecializations() const {
-  loadLazySpecializationsImpl();
-}
-
 llvm::FoldingSetVector<VarTemplateSpecializationDecl> &
 VarTemplateDecl::getSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->Specializations;
 }
 
 llvm::FoldingSetVector<VarTemplatePartialSpecializationDecl> &
 VarTemplateDecl::getPartialSpecializations() const {
-  LoadLazySpecializations();
+  loadExternalSpecializations();
   return getCommonPtr()->PartialSpecializations;
 }
 
@@ -1252,17 +1267,21 @@ VarTemplateDecl::newCommon(ASTContext &C) const {
 VarTemplateSpecializationDecl *
 VarTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
                                     void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getSpecializations(), InsertPos, Args);
 }
 
 void VarTemplateDecl::AddSpecialization(VarTemplateSpecializationDecl *D,
                                         void *InsertPos) {
+  using SETraits = SpecEntryTraits<VarTemplateSpecializationDecl>;
+  loadLazySpecializationsWithArgs(SETraits::getTemplateArgs(D));
   addSpecializationImpl<VarTemplateDecl>(getSpecializations(), D, InsertPos);
 }
 
 VarTemplatePartialSpecializationDecl *
 VarTemplateDecl::findPartialSpecialization(ArrayRef<TemplateArgument> Args,
      TemplateParameterList *TPL, void *&InsertPos) {
+  loadLazySpecializationsWithArgs(Args);
   return findSpecializationImpl(getPartialSpecializations(), InsertPos, Args,
                                 TPL);
 }
@@ -1393,6 +1412,11 @@ SourceRange VarTemplateSpecializationDecl::getSourceRange() const {
   return VarDecl::getSourceRange();
 }
 
+void VarTemplateSpecializationDecl::loadExternalRedecls() {
+  getSpecializedTemplate()->loadExternalSpecializations();
+  getSpecializedTemplate()->loadLazySpecializationsWithArgs(
+      getTemplateArgs().asArray());
+}
 
 //===----------------------------------------------------------------------===//
 // VarTemplatePartialSpecializationDecl Implementation
diff --git a/clang/lib/AST/ExternalASTSource.cpp b/clang/lib/AST/ExternalASTSource.cpp
index 090ef02aa4224d..1fb74ceb7783e4 100644
--- a/clang/lib/AST/ExternalASTSource.cpp
+++ b/clang/lib/AST/ExternalASTSource.cpp
@@ -100,6 +100,11 @@ ExternalASTSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
   return false;
 }
 
+void ExternalASTSource::LoadExternalSpecs(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  return;
+}
+
 void ExternalASTSource::completeVisibleDeclsMap(const DeclContext *DC) {}
 
 void ExternalASTSource::FindExternalLexicalDecls(
diff --git a/clang/lib/AST/ODRHash.cpp b/clang/lib/AST/ODRHash.cpp
index aea1a93ae1fa82..ace24eb4d29d85 100644
--- a/clang/lib/AST/ODRHash.cpp
+++ b/clang/lib/AST/ODRHash.cpp
@@ -1249,3 +1249,5 @@ void ODRHash::AddQualType(QualType T) {
 void ODRHash::AddBoolean(bool Value) {
   Bools.push_back(Value);
 }
+
+void ODRHash::AddInteger(unsigned Value) { ID.AddInteger(Value); }
diff --git a/clang/lib/Sema/MultiplexExternalSemaSource.cpp b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
index 058e22cb2b814e..7d8c7d7a99d645 100644
--- a/clang/lib/Sema/MultiplexExternalSemaSource.cpp
+++ b/clang/lib/Sema/MultiplexExternalSemaSource.cpp
@@ -115,6 +115,12 @@ FindExternalVisibleDeclsByName(const DeclContext *DC, DeclarationName Name) {
   return AnyDeclsFound;
 }
 
+void MultiplexExternalSemaSource::LoadExternalSpecs(
+    const Decl *D, ArrayRef<TemplateArgument> TemplateArgs) {
+  for (size_t i = 0; i < Sources.size(); ++i)
+    Sources[i]->LoadExternalSpecs(D, TemplateArgs);
+}
+
 void MultiplexExternalSemaSource::completeVisibleDeclsMap(const DeclContext *DC){
   for(size_t i = 0; i < Sources.size(); ++i)
     Sources[i]->completeVisibleDeclsMap(DC);
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 9effd333daccdb..67b4d0bddb019e 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -1223,6 +1223,38 @@ void ASTDeclContextNameLookupTrait::ReadDataInto(internal_key_type,
   }
 }
 
+ModuleFile *SpecializedDeclLookupTrait::ReadFileRef(const unsigned char *&d) {
+  using namespace llvm::support;
+
+  uint32_t ModuleFileID =
+      endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+  return Reader.getLocalModuleFile(F, ModuleFileID);
+}
+
+SpecializedDeclLookupTrait::internal_key_type
+SpecializedDeclLookupTrait::ReadKey(const unsigned char *d, unsigned) {
+  using namespace llvm::support;
+  return endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+}
+
+std::pair<unsigned, unsigned>
+SpecializedDeclLookupTrait::ReadKeyDataLength(const unsigned char *&d) {
+  return readULEBKeyDataLength(d);
+}
+
+void SpecializedDeclLookupTrait::ReadDataInto(internal_key_type,
+                                              const unsigned char *d,
+                                              unsigned DataLen,
+                                              data_type_builder &Val) {
+  using namespace llvm::support;
+
+  for (unsigned NumDecls = DataLen / 4; NumDecls; --NumDecls) {
+    uint32_t LocalID =
+        endian::readNext<uint32_t, llvm::endianness::little, unaligned>(d);
+    Val.insert(Reader.getGlobalDeclID(F, LocalID));
+  }
+}
+
 bool ASTReader::ReadLexicalDeclContextStorage(ModuleFile &M,
                                               BitstreamCursor &Cursor,
                                               uint64_t Offset,
@@ -1312,6 +1344,44 @@ bool ASTReader::ReadVisibleDeclContextStorage(ModuleFile &M,
   return false;
 }
 
+bool ASTReader::ReadDeclsSpecs(ModuleFile &M, BitstreamCursor &Cursor,
+                               uint64_t Offset, Decl *D) {
+  assert(Offset != 0);
+
+  SavedStreamPosition SavedPosition(Cursor);
+  if (llvm::Error Err = Cursor.JumpToBit(Offset)) {
+    Error(std::move(Err));
+    return true;
+  }
+
+  RecordData Record;
+  StringRef Blob;
+  Expected<unsigned> MaybeCode = Cursor.ReadCode();
+  if (!MaybeCode) {
+    Error(MaybeCode.takeError());
+    return true;
+  }
+  unsigned Code = MaybeCode.get();
+
+  Expected<unsigned> MaybeRecCode = Cursor.readRecord(Code, Record, &Blob);
+  if (!MaybeRecCode) {
+    Error(MaybeRecCode.takeError());
+    return true;
+  }
+  unsigned RecCode = MaybeRecCode.get();
+  if (RecCode != DECL_SPECS) {
+    Error("Expected decl specs block");
+    return true;
+  }
+
+  auto *Data = (const unsigned char *)Blob.data();
+  D = D->getCanonicalDecl();
+  SpecLookups[D].Table.add(&M, Data,
+                           reader::SpecializedDeclLookupTrait(*this, M));
+
+  return false;
+}
+
 void ASTReader::Error(StringRef Msg) const {
   Error(diag::err_fe_pch_malformed, Msg);
   if (PP.getLangOpts().Modules && !Diags.isDiagnosticInFlight() &&
@@ -7523,13 +7593,13 @@ void ASTReader::CompleteRedeclChain(const Decl *D) {
   }
 
   if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(D))
-    CTSD->getSpecializedTemplate()->LoadLazySpecializations();
+    const_cast<ClassTemplateSpecializationDecl *>(CTSD)->loadExternalRedecls();
   if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(D))
-    VTSD->getSpecializedTemplate()->LoadLazySpecializations();
-  if (auto *FD = dyn_cast<FunctionDecl>(D)) {
-    if (auto *Template = FD->getPrimaryTemplate())
-      Template->LoadLazySpecializations();
-  }
+    const_cast<VarTemplateSpecializationDecl *>(VTSD)->loadExternalRedecls();
+  if (auto *FD = dyn_cast<FunctionDecl>(D))
+    if (auto *FDInfo = FD->getTemplateSpecializationInfo())
+      const_cast<FunctionTemplateSpecializationInfo *>(FDInfo)
+          ->loadExternalRedecls();
 }
 
 CXXCtorInitializer **
@@ -7958,6 +8028,26 @@ ASTReader::FindExternalVisibleDeclsByName(const DeclContext *DC,
   return !Decls.empty();
 }
 
+void ASTReader::LoadExternalSpecs(const Decl *D,
+                                  ArrayRef<TemplateArgument> TemplateArgs) {
+  assert(D);
+
+  auto It = SpecLookups.find(D);
+  if (It == SpecLookups.end())
+    return;
+
+  ODRHash Hasher;
+  Hasher.AddInteger(TemplateArgs.size());
+  for (const TemplateArgument &TemplateArg : TemplateArgs)
+    Hasher.AddTemplateArgument(TemplateArg);
+  auto HashValue = Hasher.CalculateHash();
+
+  Deserializing LookupResults(this);
+
+  for (DeclID ID : It->second.Table.find(HashValue))
+    GetDecl(ID);
+}
+
 void ASTReader::completeVisibleDeclsMap(const DeclContext *DC) {
   if (!DC->hasExternalVisibleStorage())
     return;
@@ -7987,6 +8077,13 @@ ASTReader::getLoadedLookupTables(DeclContext *Primary) const {
   return I == Lookups.end() ? nullptr : &I->second;
 }
 
+serialization::reader::SpecializedDeclsLookupTable *
+ASTReader::getLoadedSpecLookupTables(Decl *D) {
+  assert(D->isCanonicalDecl());
+  auto I = SpecLookups.find(D);
+  return I == SpecLookups.end() ? nullptr : &I->second;
+}
+
 /// Under non-PCH compilation the consumer receives the objc methods
 /// before receiving the implementation, and codegen depends on this.
 /// We simulate this by deserializing and passing to consumer the methods of the
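
`LoadExternalSpecs` above hashes the requested template argument list (its size, then each argument) and calls `GetDecl` only for the IDs stored under that hash. The following is a minimal, self-contained C++ sketch of that idea. It assumes plain strings in place of `TemplateArgument`s and `std::hash` in place of `ODRHash`. `SpecLookupTable` and `hashTemplateArgs` are hypothetical names and are not part of the patch.

```cpp
// Hypothetical model of the specialization lookup table: template arguments
// are plain strings and std::hash replaces ODRHash (both are assumptions).
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

using DeclID = unsigned;
using TemplateArgs = std::vector<std::string>;

// Like the patch, hash the number of arguments first, then each argument.
std::uint64_t hashTemplateArgs(const TemplateArgs &Args) {
  std::uint64_t H = std::hash<std::size_t>{}(Args.size());
  for (const std::string &Arg : Args)
    H = (H * 1099511628211ULL) ^ std::hash<std::string>{}(Arg);
  return H;
}

class SpecLookupTable {
  std::unordered_map<std::uint64_t, std::vector<DeclID>> Buckets;

public:
  void add(const TemplateArgs &Args, DeclID ID) {
    Buckets[hashTemplateArgs(Args)].push_back(ID);
  }

  // Returns only the IDs stored under the hash of Args, so unrelated
  // specializations (and their arguments) are never touched.
  std::vector<DeclID> find(const TemplateArgs &Args) const {
    auto It = Buckets.find(hashTemplateArgs(Args));
    if (It == Buckets.end())
      return {};
    assert(!It->second.empty() && "buckets are created non-empty");
    return It->second;
  }
};
```

Because the key is only a hash, two different argument lists can share a bucket. A collision therefore makes the result a superset of the matching specializations: a few extra declarations get loaded, but none are missed.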
diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp
index 547eb77930b4ee..38582088ca0df7 100644
--- a/clang/lib/Serialization/ASTReaderDecl.cpp
+++ b/clang/lib/Serialization/ASTReaderDecl.cpp
@@ -274,9 +274,10 @@ namespace clang {
       // FIXME: We should avoid this pattern of getting the ASTContext.
       ASTContext &C = D->getASTContext();
 
-      auto *&LazySpecializations = D->getCommonPtr()->LazySpecializations;
+      auto *&ExternalSpecializations =
+          D->getCommonPtr()->ExternalSpecializations;
 
-      if (auto &Old = LazySpecializations) {
+      if (auto &Old = ExternalSpecializations) {
         IDs.insert(IDs.end(), Old + 1, Old + 1 + Old[0]);
         llvm::sort(IDs);
         IDs.erase(std::unique(IDs.begin(), IDs.end()), IDs.end());
@@ -286,7 +287,7 @@ namespace clang {
       *Result = IDs.size();
       std::copy(IDs.begin(), IDs.end(), Result + 1);
 
-      LazySpecializations = Result;
+      ExternalSpecializations = Result;
     }
 
     template <typename DeclT>
@@ -426,6 +427,8 @@ namespace clang {
 
     std::pair<uint64_t, uint64_t> VisitDeclContext(DeclContext *DC);
 
+    void ReadDeclsSpecs(ModuleFile &M, Decl *D, llvm::BitstreamCursor &Cursor);
+
     template<typename T>
     RedeclarableResult VisitRedeclarable(Redeclarable<T> *D);
 
@@ -2431,10 +2434,14 @@ void ASTDeclReader::VisitClassTemplateDecl(ClassTemplateDecl *D) {
   mergeRedeclarableTemplate(D, Redecl);
 
   if (ThisDeclID == Redecl.getFirstID()) {
-    // This ClassTemplateDecl owns a CommonPtr; read it to keep track of all of
-    // the specializations.
+    // This ClassTemplateDecl owns a CommonPtr; read it to keep track of all
+    // of the specializations.
     SmallVector<serialization::DeclID, 32> SpecIDs;
     readDeclIDList(SpecIDs);
+
+    if (Record.readInt())
+      ReadDeclsSpecs(*Loc.F, D, Loc.F->DeclsCursor);
+
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
   }
 
@@ -2463,6 +2470,10 @@ void ASTDeclReader::VisitVarTemplateDecl(VarTemplateDecl *D) {
     // the specializations.
     SmallVector<serialization::DeclID, 32> SpecIDs;
     readDeclIDList(SpecIDs);
+
+    if (Record.readInt())
+      ReadDeclsSpecs(*Loc.F, D, Loc.F->DeclsCursor);
+
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
   }
 }
@@ -2566,6 +2577,9 @@ void ASTDeclReader::VisitFunctionTemplateDecl(FunctionTemplateDecl *D) {
     SmallVector<serialization::DeclID, 32> SpecIDs;
     readDeclIDList(SpecIDs);
     ASTDeclReader::AddLazySpecializations(D, SpecIDs);
+
+    if (Record.readInt())
+      ReadDeclsSpecs(*Loc.F, D, Loc.F->DeclsCursor);
   }
 }
 
@@ -2755,6 +2769,13 @@ ASTDeclReader::VisitDeclContext(DeclContext *DC) {
   return std::make_pair(LexicalOffset, VisibleOffset);
 }
 
+void ASTDeclReader::ReadDeclsSpecs(ModuleFile &M, Decl *D,
+                                   llvm::BitstreamCursor &DeclsCursor) {
+  [[maybe_unused]] bool Failed =
+      Reader.ReadDeclsSpecs(M, DeclsCursor, ReadLocalOffset(), D);
+  assert(!Failed);
+}
+
 template <typename T>
 ASTDeclReader::RedeclarableResult
 ASTDeclReader::VisitRedeclarable(Redeclarable<T> *D) {
@@ -3800,6 +3821,7 @@ Decl *ASTReader::ReadDeclRecord(DeclID ID) {
   switch ((DeclCode)MaybeDeclCode.get()) {
   case DECL_CONTEXT_LEXICAL:
   case DECL_CONTEXT_VISIBLE:
+  case DECL_SPECS:
     llvm_unreachable("Record cannot be de-serialized with readDeclRecord");
   case DECL_TYPEDEF:
     D = TypedefDecl::CreateDeserialized(Context, ID);
@@ -4112,6 +4134,7 @@ Decl *ASTReader::ReadDeclRecord(DeclID ID) {
         ReadVisibleDeclContextStorage(*Loc.F, DeclsCursor, Offsets.second, ID))
       return nullptr;
   }
+
   assert(Record.getIdx() == Record.size());
 
   // Load any relevant update records.
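
The `ExternalSpecializations` array updated in the first ASTReaderDecl.cpp hunk above (the one `AddLazySpecializations` uses) is count-prefixed. Element 0 holds the number of IDs, and the sorted, deduplicated IDs follow it. The sketch below is a rough model of that merge. It assumes a `std::vector` in place of the `ASTContext`-allocated array, and `mergeSpecIDs` is a hypothetical name.

```cpp
// Hypothetical model of merging new specialization IDs into a
// count-prefixed array: [N, ID1, ..., IDN].
#include <algorithm>
#include <cassert>
#include <vector>

using DeclID = unsigned;

std::vector<DeclID> mergeSpecIDs(const std::vector<DeclID> &Old,
                                 std::vector<DeclID> IDs) {
  // Append any IDs that were recorded earlier.
  if (!Old.empty()) {
    assert(Old.size() == Old[0] + 1 && "malformed count-prefixed array");
    IDs.insert(IDs.end(), Old.begin() + 1, Old.end());
  }
  // Sort and drop duplicates so each specialization is recorded once.
  std::sort(IDs.begin(), IDs.end());
  IDs.erase(std::unique(IDs.begin(), IDs.end()), IDs.end());

  // Rebuild the count-prefixed representation.
  std::vector<DeclID> Result;
  Result.reserve(IDs.size() + 1);
  Result.push_back(static_cast<DeclID>(IDs.size()));
  Result.insert(Result.end(), IDs.begin(), IDs.end());
  return Result;
}
```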
diff --git a/clang/lib/Serialization/ASTReaderInternals.h b/clang/lib/Serialization/ASTReaderInternals.h
index 25a46ddabcb707..bdf541b1aa07ca 100644
--- a/clang/lib/Serialization/ASTReaderInternals.h
+++ b/clang/lib/Serialization/ASTReaderInternals.h
@@ -119,6 +119,86 @@ struct DeclContextLookupTable {
   MultiOnDiskHashTable<ASTDeclContextNameLookupTrait> Table;
 };
 
+/// Trait for looking up specializations by the ODR hash of their arguments.
+class SpecializedDeclLookupTrait {
+  ASTReader &Reader;
+  ModuleFile &F;
+
+public:
+  // Maximum number of lookup tables we allow before condensing the tables.
+  static const int MaxTables = 4;
+
+  /// The lookup result is a list of global declaration IDs.
+  using data_type = SmallVector<DeclID, 4>;
+
+  struct data_type_builder {
+    data_type &Data;
+    llvm::DenseSet<DeclID> Found;
+
+    data_type_builder(data_type &D) : Data(D) {}
+
+    void insert(DeclID ID) {
+      // Just use a linear scan unless we have more than a few IDs.
+      if (Found.empty() && !Data.empty()) {
+        if (Data.size() <= 4) {
+          for (auto I : Data)
+            if (I == ID)
+              return;
+          Data.push_back(ID);
+          return;
+        }
+
+        // Switch to tracking found IDs in the set.
+        Found.insert(Data.begin(), Data.end());
+      }
+
+      if (Found.insert(ID).second)
+        Data.push_back(ID);
+    }
+  };
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+  using file_type = ModuleFile *;
+
+  using external_key_type = unsigned;
+  using internal_key_type = unsigned;
+
+  explicit SpecializedDeclLookupTrait(ASTReader &Reader, ModuleFile &F)
+      : Reader(Reader), F(F) {}
+
+  static bool EqualKey(const internal_key_type &a, const internal_key_type &b) {
+    return a == b;
+  }
+
+  static hash_value_type ComputeHash(const internal_key_type &Key) {
+    return Key;
+  }
+
+  static internal_key_type GetInternalKey(const external_key_type &Name) {
+    return Name;
+  }
+
+  static std::pair<unsigned, unsigned>
+  ReadKeyDataLength(const unsigned char *&d);
+
+  internal_key_type ReadKey(const unsigned char *d, unsigned);
+
+  void ReadDataInto(internal_key_type, const unsigned char *d, unsigned DataLen,
+                    data_type_builder &Val);
+
+  static void MergeDataInto(const data_type &From, data_type_builder &To) {
+    To.Data.reserve(To.Data.size() + From.size());
+    for (DeclID ID : From)
+      To.insert(ID);
+  }
+
+  file_type ReadFileRef(const unsigned char *&d);
+};
+
+struct SpecializedDeclsLookupTable {
+  MultiOnDiskHashTable<SpecializedDeclLookupTrait> Table;
+};
+
 /// Base class for the trait describing the on-disk hash table for the
 /// identifiers in an AST file.
 ///
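
`SpecializedDeclLookupTrait::data_type_builder` deduplicates IDs merged from several on-disk tables. It scans linearly while the collected result is small, then switches to a set. The sketch below models that strategy with standard containers in place of `SmallVector`/`DenseSet`. `DedupBuilder` is a hypothetical name. Note that the small case has to scan the IDs already collected, because the set is still empty at that point.

```cpp
// Hypothetical model of data_type_builder: order-preserving deduplication
// that scans linearly for small results and uses a hash set otherwise.
#include <cassert>
#include <unordered_set>
#include <vector>

using DeclID = unsigned;

struct DedupBuilder {
  std::vector<DeclID> &Data;
  std::unordered_set<DeclID> Found;

  explicit DedupBuilder(std::vector<DeclID> &D) : Data(D) {}

  void insert(DeclID ID) {
    if (Found.empty() && !Data.empty()) {
      if (Data.size() <= 4) {
        // Small case: scan the IDs collected so far.
        for (DeclID I : Data)
          if (I == ID)
            return;
        Data.push_back(ID);
        return;
      }
      // Large case: seed the set with everything collected so far.
      Found.insert(Data.begin(), Data.end());
      assert(Found.size() == Data.size() && "Data should hold unique IDs");
    }
    if (Found.insert(ID).second)
      Data.push_back(ID);
  }
};
```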
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 78939bfd533ffa..162f2973607fc2 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -29,6 +29,7 @@
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/LambdaCapture.h"
 #include "clang/AST/NestedNameSpecifier.h"
+#include "clang/AST/ODRHash.h"
 #include "clang/AST/OpenMPClause.h"
 #include "clang/AST/RawCommentList.h"
 #include "clang/AST/TemplateName.h"
@@ -3924,6 +3925,152 @@ class ASTDeclContextNameLookupTrait {
 
 } // namespace
 
+namespace {
+class SpecializedDeclLookupTrait {
+  ASTWriter &Writer;
+  llvm::SmallVector<DeclID, 64> DeclIDs;
+
+public:
+  using key_type = unsigned;
+  using key_type_ref = key_type;
+
+  /// A start and end index into DeclIDs, representing a sequence of decls.
+  using data_type = std::pair<unsigned, unsigned>;
+  using data_type_ref = const data_type &;
+
+  using hash_value_type = unsigned;
+  using offset_type = unsigned;
+
+  explicit SpecializedDeclLookupTrait(ASTWriter &Writer) : Writer(Writer) {}
+
+  template <typename Col> data_type getData(Col &&C) {
+    unsigned Start = DeclIDs.size();
+    for (auto *D : C)
+      DeclIDs.push_back(Writer.GetDeclRef(getDeclForLocalLookup(
+          Writer.getLangOpts(), const_cast<NamedDecl *>(D))));
+    return std::make_pair(Start, DeclIDs.size());
+  }
+
+  data_type
+  ImportData(const reader::SpecializedDeclLookupTrait::data_type &FromReader) {
+    unsigned Start = DeclIDs.size();
+    for (auto ID : FromReader)
+      DeclIDs.push_back(ID);
+    return std::make_pair(Start, DeclIDs.size());
+  }
+
+  static bool EqualKey(key_type_ref a, key_type_ref b) { return a == b; }
+
+  hash_value_type ComputeHash(key_type Name) { return Name; }
+
+  void EmitFileRef(raw_ostream &Out, ModuleFile *F) const {
+    assert(Writer.hasChain() &&
+           "have reference to loaded module file but no chain?");
+
+    using namespace llvm::support;
+    endian::write<uint32_t>(Out, Writer.getChain()->getModuleFileID(F),
+                            llvm::endianness::little);
+  }
+
+  std::pair<unsigned, unsigned> EmitKeyDataLength(raw_ostream &Out,
+                                                  key_type HashValue,
+                                                  data_type_ref Lookup) {
+    // 4 bytes for each slot.
+    unsigned KeyLen = 4;
+    unsigned DataLen = 4 * (Lookup.second - Lookup.first);
+
+    return emitULEBKeyDataLength(KeyLen, DataLen, Out);
+  }
+
+  void EmitKey(raw_ostream &Out, key_type HashValue, unsigned) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    LE.write<uint32_t>(HashValue);
+  }
+
+  void EmitData(raw_ostream &Out, key_type_ref, data_type Lookup,
+                unsigned DataLen) {
+    using namespace llvm::support;
+
+    endian::Writer LE(Out, llvm::endianness::little);
+    uint64_t Start = Out.tell();
+    (void)Start;
+    for (unsigned I = Lookup.first, N = Lookup.second; I != N; ++I)
+      LE.write<uint32_t>(DeclIDs[I]);
+    assert(Out.tell() - Start == DataLen && "Data length is wrong");
+  }
+};
+
+unsigned GetTemplateArgsODRHash(ArrayRef<TemplateArgument> TemplateArgs) {
+  ODRHash Hasher;
+  Hasher.AddInteger(TemplateArgs.size());
+  for (const TemplateArgument &TemplateArg : TemplateArgs)
+    Hasher.AddTemplateArgument(TemplateArg);
+  return Hasher.CalculateHash();
+}
+
+unsigned CalculateODRHashForSpecs(const Decl *Spec) {
+  assert(!isa<ClassTemplatePartialSpecializationDecl>(Spec) &&
+         !isa<VarTemplatePartialSpecializationDecl>(Spec) &&
+         "We shouldn't see partial specializations here.");
+
+  if (auto *FD = dyn_cast<FunctionDecl>(Spec)) {
+    auto *FDInfo = FD->getTemplateSpecializationInfo();
+    assert(FDInfo);
+    return GetTemplateArgsODRHash(FDInfo->TemplateArguments->asArray());
+  }
+
+  if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(Spec))
+    return GetTemplateArgsODRHash(CTSD->getTemplateArgs().asArray());
+
+  if (auto *VTSD = dyn_cast<VarTemplateSpecializationDecl>(Spec))
+    return GetTemplateArgsODRHash(VTSD->getTemplateArgs().asArray());
+
+  llvm_unreachable("Unhandled specialization kind?");
+}
+} // namespace
+
+uint64_t ASTWriter::WriteSpecsLookupTable(
+    NamedDecl *D, llvm::SmallVectorImpl<const NamedDecl *> &Specs) {
+  assert(D->isFirstDecl());
+
+  // Create the on-disk hash table representation.
+  MultiOnDiskHashTableGenerator<reader::SpecializedDeclLookupTrait,
+                                SpecializedDeclLookupTrait>
+      Generator;
+  SpecializedDeclLookupTrait Trait(*this);
+
+  llvm::DenseMap<unsigned, llvm::SmallVector<const NamedDecl *, 4>> SpecsMaps;
+
+  for (auto *Spec : Specs) {
+    unsigned HashedValue = CalculateODRHashForSpecs(Spec);
+
+    auto Iter = SpecsMaps.find(HashedValue);
+    if (Iter == SpecsMaps.end())
+      Iter = SpecsMaps
+                 .try_emplace(HashedValue,
+                              llvm::SmallVector<const NamedDecl *, 4>())
+                 .first;
+
+    Iter->second.push_back(Spec);
+  }
+
+  for (auto &Iter : SpecsMaps)
+    Generator.insert(Iter.first, Trait.getData(Iter.second), Trait);
+
+  uint64_t Offset = Stream.GetCurrentBitNo();
+
+  auto *Lookups = Chain ? Chain->getLoadedSpecLookupTables(D) : nullptr;
+  llvm::SmallString<4096> LookupTable;
+  Generator.emit(LookupTable, Trait, Lookups ? &Lookups->Table : nullptr);
+
+  RecordData::value_type Record[] = {DECL_SPECS};
+  Stream.EmitRecordWithBlob(DeclSpecsAbbrev, Record, LookupTable);
+
+  return Offset;
+}
+
 bool ASTWriter::isLookupResultExternal(StoredDeclsList &Result,
                                        DeclContext *DC) {
   return Result.hasExternalDecls() &&
@@ -5074,7 +5221,7 @@ ASTFileSignature ASTWriter::WriteASTCore(Sema &SemaRef, StringRef isysroot,
 
   // Keep writing types, declarations, and declaration update records
   // until we've emitted all of them.
-  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/5);
+  Stream.EnterSubblock(DECLTYPES_BLOCK_ID, /*bits for abbreviations*/ 6);
   DeclTypesBlockStartOffset = Stream.GetCurrentBitNo();
   WriteTypeAbbrevs();
   WriteDeclAbbrevs();
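
Each entry that the writer-side `SpecializedDeclLookupTrait` emits has the following layout:

- a 4-byte little-endian key (the ODR hash), written by `EmitKey`;
- 4 little-endian bytes per `DeclID`, written by `EmitData`.

This matches `KeyLen = 4` and `DataLen = 4 * N` in `EmitKeyDataLength`. The sketch below round-trips that payload. It assumes byte vectors in place of `raw_ostream` and omits the `MultiOnDiskHashTable` framing (bucket layout and the lengths written by `emitULEBKeyDataLength`). All function names are hypothetical.

```cpp
// Hypothetical sketch of one entry's payload: a 4-byte little-endian key
// followed by 4 little-endian bytes per DeclID. Byte vectors replace
// raw_ostream, and the on-disk hash table framing is not modeled.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

void writeLE32(std::vector<std::uint8_t> &Out, std::uint32_t V) {
  // Least significant byte first.
  for (int Shift = 0; Shift < 32; Shift += 8)
    Out.push_back(static_cast<std::uint8_t>(V >> Shift));
}

std::uint32_t readLE32(const std::uint8_t *P) {
  return std::uint32_t(P[0]) | std::uint32_t(P[1]) << 8 |
         std::uint32_t(P[2]) << 16 | std::uint32_t(P[3]) << 24;
}

std::vector<std::uint8_t> emitEntry(std::uint32_t HashKey,
                                    const std::vector<std::uint32_t> &IDs) {
  std::vector<std::uint8_t> Out;
  writeLE32(Out, HashKey); // the key, as in EmitKey
  for (std::uint32_t ID : IDs)
    writeLE32(Out, ID); // the data, as in EmitData
  // KeyLen = 4 and DataLen = 4 * N, as computed in EmitKeyDataLength.
  assert(Out.size() == 4 + 4 * IDs.size());
  return Out;
}

std::vector<std::uint32_t> readEntryIDs(const std::vector<std::uint8_t> &In) {
  assert(In.size() >= 4 && In.size() % 4 == 0 && "malformed entry");
  std::vector<std::uint32_t> IDs;
  for (std::size_t Off = 4; Off < In.size(); Off += 4)
    IDs.push_back(readLE32(In.data() + Off));
  return IDs;
}
```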
diff --git a/clang/lib/Serialization/ASTWriterDecl.cpp b/clang/lib/Serialization/ASTWriterDecl.cpp
index 9e3299f0491848..756e4a03a6a3cb 100644
--- a/clang/lib/Serialization/ASTWriterDecl.cpp
+++ b/clang/lib/Serialization/ASTWriterDecl.cpp
@@ -207,22 +207,26 @@ namespace clang {
       return std::nullopt;
     }
 
-    template<typename DeclTy>
-    void AddTemplateSpecializations(DeclTy *D) {
+    template <typename DeclTy>
+    void AddTemplateSpecializations(
+        DeclTy *D, llvm::SmallVectorImpl<const NamedDecl *> &OptionalSpecs) {
       auto *Common = D->getCommonPtr();
 
       // If we have any lazy specializations, and the external AST source is
       // our chained AST reader, we can just write out the DeclIDs. Otherwise,
       // we need to resolve them to actual declarations.
       if (Writer.Chain != Writer.Context->getExternalSource() &&
-          Common->LazySpecializations) {
-        D->LoadLazySpecializations();
-        assert(!Common->LazySpecializations);
+          Common->ExternalSpecializations) {
+        D->loadExternalSpecializations();
+        assert(!Common->ExternalSpecializations);
       }
 
-      ArrayRef<DeclID> LazySpecializations;
-      if (auto *LS = Common->LazySpecializations)
-        LazySpecializations = llvm::ArrayRef(LS + 1, LS[0]);
+      for (auto &Entry : Common->Specializations)
+        OptionalSpecs.push_back(getSpecializationDecl(Entry));
+
+      ArrayRef<DeclID> ExternalSpecializations;
+      if (auto *LS = Common->ExternalSpecializations)
+        ExternalSpecializations = llvm::ArrayRef(LS + 1, LS[0]);
 
       // Add a slot to the record for the number of specializations.
       unsigned I = Record.size();
@@ -230,17 +234,20 @@ namespace clang {
 
       // AddFirstDeclFromEachModule might trigger deserialization, invalidating
       // *Specializations iterators.
-      llvm::SmallVector<const Decl*, 16> Specs;
-      for (auto &Entry : Common->Specializations)
-        Specs.push_back(getSpecializationDecl(Entry));
+      //
+      // All partial specializations must be loaded together whenever the
+      // template is needed, since we cannot tell whether a partial
+      // specialization applies before resolving an instantiation request.
+      llvm::SmallVector<const Decl *, 16> PartialSpecs;
       for (auto &Entry : getPartialSpecializations(Common))
-        Specs.push_back(getSpecializationDecl(Entry));
+        PartialSpecs.push_back(getSpecializationDecl(Entry));
 
-      for (auto *D : Specs) {
+      for (auto *D : PartialSpecs) {
         assert(D->isCanonicalDecl() && "non-canonical decl in set");
         AddFirstDeclFromEachModule(D, /*IncludeLocal*/true);
       }
-      Record.append(LazySpecializations.begin(), LazySpecializations.end());
+      Record.append(ExternalSpecializations.begin(),
+                    ExternalSpecializations.end());
 
       // Update the size entry we added earlier.
       Record[I] = Record.size() - I - 1;
@@ -1670,8 +1677,15 @@ void ASTDeclWriter::VisitRedeclarableTemplateDecl(RedeclarableTemplateDecl *D) {
 void ASTDeclWriter::VisitClassTemplateDecl(ClassTemplateDecl *D) {
   VisitRedeclarableTemplateDecl(D);
 
-  if (D->isFirstDecl())
-    AddTemplateSpecializations(D);
+  if (D->isFirstDecl()) {
+    llvm::SmallVector<const NamedDecl *, 16> OptionalSpecs;
+    AddTemplateSpecializations(D, OptionalSpecs);
+    if (!OptionalSpecs.empty()) {
+      Record.push_back(1);
+      Record.AddOffset(Writer.WriteSpecsLookupTable(D, OptionalSpecs));
+    } else
+      Record.push_back(0);
+  }
   Code = serialization::DECL_CLASS_TEMPLATE;
 }
 
@@ -1730,8 +1744,16 @@ void ASTDeclWriter::VisitClassTemplatePartialSpecializationDecl(
 void ASTDeclWriter::VisitVarTemplateDecl(VarTemplateDecl *D) {
   VisitRedeclarableTemplateDecl(D);
 
-  if (D->isFirstDecl())
-    AddTemplateSpecializations(D);
+  if (D->isFirstDecl()) {
+    llvm::SmallVector<const NamedDecl *, 16> OptionalSpecs;
+    AddTemplateSpecializations(D, OptionalSpecs);
+    if (!OptionalSpecs.empty()) {
+      Record.push_back(1);
+      Record.AddOffset(Writer.WriteSpecsLookupTable(D, OptionalSpecs));
+    } else
+      Record.push_back(0);
+  }
+
   Code = serialization::DECL_VAR_TEMPLATE;
 }
 
@@ -1791,8 +1813,16 @@ void ASTDeclWriter::VisitVarTemplatePartialSpecializationDecl(
 void ASTDeclWriter::VisitFunctionTemplateDecl(FunctionTemplateDecl *D) {
   VisitRedeclarableTemplateDecl(D);
 
-  if (D->isFirstDecl())
-    AddTemplateSpecializations(D);
+  if (D->isFirstDecl()) {
+    llvm::SmallVector<const NamedDecl *, 16> OptionalSpecs;
+    AddTemplateSpecializations(D, OptionalSpecs);
+    if (!OptionalSpecs.empty()) {
+      Record.push_back(1);
+      Record.AddOffset(Writer.WriteSpecsLookupTable(D, OptionalSpecs));
+    } else
+      Record.push_back(0);
+  }
+
   Code = serialization::DECL_FUNCTION_TEMPLATE;
 }
 
@@ -2657,6 +2687,11 @@ void ASTWriter::WriteDeclAbbrevs() {
   Abv->Add(BitCodeAbbrevOp(serialization::DECL_CONTEXT_VISIBLE));
   Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
   DeclContextVisibleLookupAbbrev = Stream.EmitAbbrev(std::move(Abv));
+
+  Abv = std::make_shared<BitCodeAbbrev>();
+  Abv->Add(BitCodeAbbrevOp(serialization::DECL_SPECS));
+  Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Blob));
+  DeclSpecsAbbrev = Stream.EmitAbbrev(std::move(Abv));
 }
 
 /// isRequiredDecl - Check if this is a "required" Decl, which must be seen by
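
For class, variable, and function templates, the writer now appends a flag after the specialization ID list:

- `1` followed by the offset of the `DECL_SPECS` lookup table, when the template has any (non-partial) specializations;
- `0` otherwise.

The reader-side `Visit*TemplateDecl` functions consume the same fields via `Record.readInt()` and `ReadDeclsSpecs`. The sketch below is a simplified model of that record protocol. It assumes a plain `std::vector<uint64_t>` for the record and does not model how `AddOffset`/`ReadLocalOffset` actually encode the offset. The function names are hypothetical.

```cpp
// Hypothetical model of the optional lookup-table reference in a template
// decl record: [1, Offset] when a table was written, [0] otherwise.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

void writeSpecsTableRef(std::vector<std::uint64_t> &Record,
                        std::optional<std::uint64_t> Offset) {
  if (Offset) {
    Record.push_back(1);
    Record.push_back(*Offset);
  } else {
    Record.push_back(0);
  }
}

std::optional<std::uint64_t>
readSpecsTableRef(const std::vector<std::uint64_t> &Record, std::size_t &Idx) {
  assert(Idx < Record.size() && "record too short");
  if (!Record[Idx++])
    return std::nullopt; // No DECL_SPECS record for this template.
  assert(Idx < Record.size() && "missing lookup table offset");
  return Record[Idx++];
}
```

The flag lets templates without any full specializations skip the extra `DECL_SPECS` record entirely.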
diff --git a/clang/test/Modules/odr_hash.cpp b/clang/test/Modules/odr_hash.cpp
index 220ef767df849a..4483cf21e5c623 100644
--- a/clang/test/Modules/odr_hash.cpp
+++ b/clang/test/Modules/odr_hash.cpp
@@ -2897,7 +2897,7 @@ struct S5 {
 };
 #else
 S5 s5;
-// expected-error@second.h:* {{'PointersAndReferences::S5::x' from module 'SecondModule' is not present in definition of 'PointersAndReferences::S5' in module 'FirstModule'}}
+// expected-error@first.h:* {{'PointersAndReferences::S5::x' from module 'FirstModule' is not present in definition of 'PointersAndReferences::S5' in module 'SecondModule'}}
 // expected-note@first.h:* {{declaration of 'x' does not match}}
 #endif
 
@@ -3834,7 +3834,7 @@ struct Valid {
 #else
 Invalid::L2<1>::L3<1> invalid;
 // expected-error@second.h:* {{'Types::InjectedClassName::Invalid::L2::L3::x' from module 'SecondModule' is not present in definition of 'L3<>' in module 'FirstModule'}}
-// expected-note@first.h:* {{declaration of 'x' does not match}}
+// expected-note@second.h:* {{declaration of 'x' does not match}}
 Valid::L2<1>::L3<1> valid;
 #endif
 }  // namespace InjectedClassName
diff --git a/clang/test/Modules/static-member-in-templates.cppm b/clang/test/Modules/static-member-in-templates.cppm
new file mode 100644
index 00000000000000..c419eee8097a88
--- /dev/null
+++ b/clang/test/Modules/static-member-in-templates.cppm
@@ -0,0 +1,52 @@
+// RUN: rm -rf %t
+// RUN: mkdir -p %t
+// RUN: split-file %s %t
+//
+// RUN: %clang_cc1 -std=c++20 %t/A.cppm -emit-module-interface -o %t/A.pcm
+// RUN: %clang_cc1 -std=c++20 %t/B.cppm -fmodule-file=A=%t/A.pcm -emit-module-interface -o %t/B.pcm
+// RUN: %clang_cc1 -std=c++20 %t/use.cpp -fprebuilt-module-path=%t -fsyntax-only -verify
+
+//--- foo.h
+
+class DefaultAlloc {
+public:
+    using size_type = unsigned;
+};
+
+template <class T, class Alloc = DefaultAlloc>
+class A {
+public:
+    using size_type = Alloc::size_type;
+    static const size_type value = ~0x0;
+
+    size_type
+    getValue() {
+        return value;
+    }
+};
+
+template <class T, class Alloc>
+const typename A<T, Alloc>::size_type value;
+
+typedef A<char> AC;
+
+//--- A.cppm
+module;
+#include "foo.h"
+export module A;
+export using ::AC;
+
+//--- B.cppm
+module;
+#include "foo.h"
+export module B;
+import A;
+
+//--- use.cpp
+// expected-no-diagnostics
+import A;
+import B;
+int use() {
+    AC a;
+    return a.getValue();
+}
diff --git a/clang/unittests/Serialization/CMakeLists.txt b/clang/unittests/Serialization/CMakeLists.txt
index 10d7de970c643d..6276bbb6d0cb4b 100644
--- a/clang/unittests/Serialization/CMakeLists.txt
+++ b/clang/unittests/Serialization/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_LINK_COMPONENTS
 add_clang_unittest(SerializationTests
   ForceCheckFileInputTest.cpp
   InMemoryModuleCacheTest.cpp
+  LoadSpecLazily.cpp
   ModuleCacheTest.cpp
   NoCommentsTest.cpp
   SourceLocationEncodingTest.cpp
diff --git a/clang/unittests/Serialization/LoadSpecLazily.cpp b/clang/unittests/Serialization/LoadSpecLazily.cpp
new file mode 100644
index 00000000000000..03f3ff3f786555
--- /dev/null
+++ b/clang/unittests/Serialization/LoadSpecLazily.cpp
@@ -0,0 +1,159 @@
+//===- unittests/Serialization/LoadSpecLazily.cpp -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Frontend/FrontendActions.h"
+#include "clang/Parse/ParseAST.h"
+#include "clang/Serialization/ASTDeserializationListener.h"
+#include "clang/Tooling/Tooling.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace clang;
+using namespace clang::tooling;
+
+namespace {
+
+class LoadSpecLazilyTest : public ::testing::Test {
+  void SetUp() override {
+    ASSERT_FALSE(
+        sys::fs::createUniqueDirectory("load-spec-lazily-test", TestDir));
+  }
+
+  void TearDown() override { sys::fs::remove_directories(TestDir); }
+
+public:
+  SmallString<256> TestDir;
+
+  void addFile(StringRef Path, StringRef Contents) {
+    ASSERT_FALSE(sys::path::is_absolute(Path));
+
+    SmallString<256> AbsPath(TestDir);
+    sys::path::append(AbsPath, Path);
+
+    ASSERT_FALSE(
+        sys::fs::create_directories(llvm::sys::path::parent_path(AbsPath)));
+
+    std::error_code EC;
+    llvm::raw_fd_ostream OS(AbsPath, EC);
+    ASSERT_FALSE(EC);
+    OS << Contents;
+  }
+
+  std::string GenerateModuleInterface(StringRef ModuleName,
+                                      StringRef Contents) {
+    std::string FileName = llvm::Twine(ModuleName + ".cppm").str();
+    addFile(FileName, Contents);
+
+    IntrusiveRefCntPtr<DiagnosticsEngine> Diags =
+        CompilerInstance::createDiagnostics(new DiagnosticOptions());
+    CreateInvocationOptions CIOpts;
+    CIOpts.Diags = Diags;
+    CIOpts.VFS = llvm::vfs::createPhysicalFileSystem();
+
+    std::string CacheBMIPath =
+        llvm::Twine(TestDir + "/" + ModuleName + ".pcm").str();
+    std::string PrebuiltModulePath =
+        "-fprebuilt-module-path=" + TestDir.str().str();
+    const char *Args[] = {"clang++",
+                          "-std=c++20",
+                          "--precompile",
+                          PrebuiltModulePath.c_str(),
+                          "-working-directory",
+                          TestDir.c_str(),
+                          "-I",
+                          TestDir.c_str(),
+                          FileName.c_str(),
+                          "-o",
+                          CacheBMIPath.c_str()};
+    std::shared_ptr<CompilerInvocation> Invocation =
+        createInvocation(Args, CIOpts);
+    EXPECT_TRUE(Invocation);
+
+    CompilerInstance Instance;
+    Instance.setDiagnostics(Diags.get());
+    Instance.setInvocation(Invocation);
+    GenerateModuleInterfaceAction Action;
+    EXPECT_TRUE(Instance.ExecuteAction(Action));
+    EXPECT_FALSE(Diags->hasErrorOccurred());
+
+    return CacheBMIPath;
+  }
+};
+
+class DeclsReaderListener : public ASTDeserializationListener {
+public:
+  void DeclRead(serialization::DeclID ID, const Decl *D) override {
+    auto *ND = dyn_cast<NamedDecl>(D);
+    if (!ND)
+      return;
+
+    EXPECT_FALSE(ND->getName().contains(ForbiddenName));
+  }
+
+  DeclsReaderListener(StringRef ForbiddenName) : ForbiddenName(ForbiddenName) {}
+
+  StringRef ForbiddenName;
+};
+
+class LoadSpecLazilyConsumer : public ASTConsumer {
+  DeclsReaderListener Listener;
+
+public:
+  LoadSpecLazilyConsumer(StringRef ForbiddenName) : Listener(ForbiddenName) {}
+
+  ASTDeserializationListener *GetASTDeserializationListener() override {
+    return &Listener;
+  }
+};
+
+class CheckLoadSpecLazilyAction : public ASTFrontendAction {
+  StringRef ForbiddenName;
+
+public:
+  std::unique_ptr<ASTConsumer>
+  CreateASTConsumer(CompilerInstance &CI, StringRef /*Unused*/) override {
+    return std::make_unique<LoadSpecLazilyConsumer>(ForbiddenName);
+  }
+
+  CheckLoadSpecLazilyAction(StringRef ForbiddenName)
+      : ForbiddenName(ForbiddenName) {}
+};
+
+TEST_F(LoadSpecLazilyTest, BasicTest) {
+  GenerateModuleInterface("M", R"cpp(
+export module M;
+export template <class T>
+class A {};
+
+export class ShouldNotBeLoaded {};
+
+export class Temp {
+   A<ShouldNotBeLoaded> AS;
+};
+  )cpp");
+
+  const char *test_file_contents = R"cpp(
+import M;
+A<int> a;
+  )cpp";
+  std::string DepArg = "-fprebuilt-module-path=" + TestDir.str().str();
+  EXPECT_TRUE(runToolOnCodeWithArgs(
+      std::make_unique<CheckLoadSpecLazilyAction>("ShouldNotBeLoaded"),
+      test_file_contents,
+      {
+          "-std=c++20",
+          DepArg.c_str(),
+          "-I",
+          TestDir.c_str(),
+      },
+      "test.cpp"));
+}
+
+} // namespace


